diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py index c46a54b0efa..a3fe049b8bb 100644 --- a/backends/arm/_passes/decompose_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -51,7 +51,7 @@ def call_operator(self, op, args, kwargs, meta): full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) x = args[0] - full_kwargs = {"device": x.data.device} + full_kwargs = {"device": x.data.device, "dtype": x.data.dtype} kernel_h, kernel_w = args[1] kernel_size = kernel_h * kernel_w if len(args) > 2 and args[2] is not None: diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 37cac8a8c56..f31675b8daf 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -54,15 +54,15 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # ensure we don't merge any special case int48_t tensors with int32_t tensors # since int48_t tensors needs to be instantiated separately. is_special_dtype = node.meta.get(TosaSpecialDtype.meta_key(), None) - t_cpu = tensor.detach().cpu().contiguous() + t_cpu = tensor.cpu().contiguous().flatten().view(dtype=torch.uint8) data_bytes = t_cpu.numpy().tobytes() key = ( is_special_dtype, - str(t_cpu.dtype), - tuple(t_cpu.shape), + str(tensor.dtype), + tuple(tensor.shape), hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(), ) - hash_buckets[key].append((node, t_cpu)) + hash_buckets[key].append((node, tensor)) # For each bucket with more than one entry, fuse: for nodes_tensors in hash_buckets.values(): diff --git a/backends/arm/_passes/match_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py index f0aaa0cf5f9..edf06bcf890 100644 --- a/backends/arm/_passes/match_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,8 +19,9 @@ torch.int32: 4, torch.int64: 5, torch.float16: 6, - torch.float32: 7, - torch.float64: 8, + torch.bfloat16: 7, + torch.float32: 8, + torch.float64: 9, } diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index b5a58136395..b21407591a5 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,8 +43,8 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 6c1ff2e1449..3cb0dc8b700 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,8 +43,8 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index e4824fb59c2..d698f4ccd72 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) input = inputs[0] diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 34d4d37cdeb..c3b1bdd4786 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) input = inputs[0] diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 2a850c0cf52..d2cec0b8ea8 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -37,7 +37,7 @@ def define_node( validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( - self.target, [inputs[0], output], ts.DType.BOOL, output.tosa_spec + self.target, [inputs[0], output], ts.DType.BOOL, self.tosa_spec ) input_shape = list(inputs[0].shape) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index ec9d42915c1..cc180ec47b7 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -115,14 +115,14 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) - supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + supported_dtypes = [ts.DType.INT8, ts.DType.FP32, ts.DType.BF16] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: diff --git a/backends/arm/operators/op_bitwise_not.py b/backends/arm/operators/op_bitwise_not.py index ac0f758469d..c6063ee132e 100644 --- a/backends/arm/operators/op_bitwise_not.py +++ b/backends/arm/operators/op_bitwise_not.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 71c18530d55..40debd29685 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -37,17 +37,23 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + supported_dtypes = [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT32, + ts.DType.FP32, + ts.DType.BF16, + ] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_num_inputs(self.target, inputs, [1, 2]) - input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special] + input_tosa_args = [TosaArg(arg, self.tosa_spec) for arg in inputs[0].special] validate_same_dtype(self.target, [*input_tosa_args, output], ts) validate_valid_dtype( self.target, [*input_tosa_args, output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) dim = 0 if len(inputs) < 2 else inputs[1].number diff --git a/backends/arm/operators/op_ceil.py b/backends/arm/operators/op_ceil.py index 27ee81d0abe..80bdc0f86bd 100644 --- a/backends/arm/operators/op_ceil.py +++ b/backends/arm/operators/op_ceil.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -45,8 +45,8 @@ def define_node( validate_valid_dtype( self.target, inputs[0], - ts.DType.FP32, - output.tosa_spec, + [ts.DType.FP32, ts.DType.BF16], + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index d90f92f5e4b..faa40d836ef 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree @@ -6,7 +6,6 @@ from typing import Any, List, Tuple -import numpy as np import torch import tosa_serializer as ts @@ -67,16 +66,7 @@ def cast_type(value: Any) -> int | float: return min_arg, max_arg def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes: - if dtype == torch.float32: - return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist() - elif dtype == torch.float16: - return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist() - elif dtype == torch.int8: - return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist() - elif dtype == torch.int16: - return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist() - else: - raise ValueError(f"Unsupported dtype for to_bytes: {dtype}") + return torch.full((1,), value, dtype=dtype).view(torch.uint8).numpy().tolist() def define_node( self, @@ -87,14 +77,19 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [2, 3]) validate_same_dtype(self.target, [inputs[0], output], ts) - supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32] + supported_dtypes = [ + ts.DType.INT8, + ts.DType.FP16, + ts.DType.BF16, + ts.DType.FP32, + ] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) node_input_dtype = node.meta["val"].dtype diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 47d11fb5627..ee9f268cdc1 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -55,7 +55,7 @@ def define_node( ts.DType.FP32, ts.DType.BOOL, ], - output.tosa_spec, + self.tosa_spec, ) if inputs[0].dtype == ts.DType.INT8: diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index e6039730b69..30127fca3b9 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -40,7 +40,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.CosAttribute() diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index bd72c9491ca..0268a1d3ec3 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.EqualAttribute() diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e642a4059fe..676bac392f0 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,7 +43,7 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) # MI lowering diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 72e89b6906b..4d11558fe75 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_floor.py b/backends/arm/operators/op_floor.py index d9f831dfb35..ad0285e6fc7 100644 --- a/backends/arm/operators/op_floor.py +++ b/backends/arm/operators/op_floor.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, inputs[0], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 754778487e9..e23f63d263d 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterEqualAttribute() diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 2a483f735a7..31751ad0ed3 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterAttribute() diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index ba2aa03c7ff..a205bc29b53 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -56,7 +56,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) weights, _, indices = inputs diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index cd0809df95b..7fc2aa5f027 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def _get_tensor_info(self, tensor: Node): """ if isinstance(tensor, Node): - dtype, shape, _ = extract_tensor_meta(tensor.meta, self.tosa_spec) + dtype, shape, _ = extract_tensor_meta(tensor.meta) return tensor.name, dtype, shape else: return tensor.name, tensor.dtype, tensor.shape @@ -133,9 +133,7 @@ def define_node( index_nodes = indices.special # Broadcast indices - broadcasted_tensors = tutils.broadcast_tensors( - tosa_graph, index_nodes, self.tosa_spec - ) + broadcasted_tensors = tutils.broadcast_tensors(tosa_graph, index_nodes) # Calculate strides so we can shift indices down the line. values_strides = self._calculate_value_strides(values.shape) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index aa6b52b9982..5a49e1de420 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterEqualAttribute() diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 565d6d56027..7ec0e6e082e 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,7 +41,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.LogAttribute() diff --git a/backends/arm/operators/op_logical_not.py b/backends/arm/operators/op_logical_not.py index 695af5f7a26..81b44326ae2 100644 --- a/backends/arm/operators/op_logical_not.py +++ b/backends/arm/operators/op_logical_not.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -47,7 +47,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.BOOL], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 4b2b1a1960b..d51ac28b5b9 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterAttribute() diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index bee0cc3fb0c..8d4bd4f6635 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -51,7 +51,7 @@ def define_node( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) input_tensor = inputs[0] diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index d3ab305ea3b..a44e20f6657 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,8 +46,8 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], + self.tosa_spec, ) attr_maximum = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 7f72d158d43..3fb9d23ccfd 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -47,8 +47,8 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], + self.tosa_spec, ) attr_minimum = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 0e10443e523..dc7a2277738 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -45,7 +45,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift") diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index e0bb408e155..621c9872b31 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -66,7 +66,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], supported_dtypes, output.tosa_spec + self.target, [*inputs, output], supported_dtypes, self.tosa_spec ) input_zp, output_zp = get_negate_zero_points( diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index fea0aea9298..b76c427d622 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -123,7 +123,7 @@ def define_node( ts.DType.INT32, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) # The permutation vector describes a permutation P in default Pytorch dim_order. diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 33cbc290d2c..e159fe66bcc 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.FP16, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() attr.PowAttribute() diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 108a4fac0fb..77478edbf21 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.ReciprocalAttribute() diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index e44fede736d..29483501e5c 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -50,7 +50,7 @@ def define_node( ts.DType.INT32, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) multiples = inputs[1].special diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 0b5717aa403..8971622b9a7 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,7 +41,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index a86eaa40985..a98ace6b759 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.RsqrtAttribute() diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 908544ff00c..24d04be93c2 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,7 +41,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.SigmoidAttribute() diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index faa249917c3..528ff0227ad 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -40,7 +40,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.SinAttribute() diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 21c86e5f7c4..9dac73d84ba 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -80,7 +80,7 @@ def define_node( ts.DType.INT32, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) # See slice_copy_support.py diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 039a2f6bd68..db85e2d78c7 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index e956359736c..ca841e85a21 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) tensor = inputs[0] diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index c4603e90118..ecc7a261816 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.TanhAttribute() diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 993caff9867..777b450abfb 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -58,7 +58,7 @@ def define_node( self.target, [*inputs], supported_input_dtypes, - output.tosa_spec, + self.tosa_spec, ) supported_output_dtypes = [ts.DType.INT32, ts.DType.FP32] if self.tosa_spec.support_extension("int16"): @@ -67,7 +67,7 @@ def define_node( self.target, [output], supported_output_dtypes, - output.tosa_spec, + self.tosa_spec, ) # We need to get the zero points and add an intermediate tensor for INT16 case diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index ae87dcc9c31..77f3569bbac 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -235,8 +235,8 @@ def define_node( if ( input_dtype not in [ - map_dtype(torch.int8, self.tosa_spec), - map_dtype(torch.int16, self.tosa_spec), + map_dtype(torch.int8), + map_dtype(torch.int16), ] and input_zp != 0 ): diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index e7e63f155d3..f40f3283eec 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, [inputs[0]], supported_input_dtypes, - output.tosa_spec, + self.tosa_spec, ) supported_output_dtypes = [ts.DType.FP32] if node.kwargs.get("resize_mode") == "bilinear": @@ -63,7 +63,7 @@ def define_node( if self.tosa_spec.support_extension("int16"): supported_output_dtypes.append(ts.DType.INT16) validate_valid_dtype( - self.target, [output], supported_output_dtypes, output.tosa_spec + self.target, [output], supported_output_dtypes, self.tosa_spec ) # tosa_shape output is NHWC, take HW input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ diff --git a/backends/arm/operators/op_tosa_scatter.py b/backends/arm/operators/op_tosa_scatter.py index f0308e024ef..904965dfe94 100644 --- a/backends/arm/operators/op_tosa_scatter.py +++ b/backends/arm/operators/op_tosa_scatter.py @@ -46,7 +46,7 @@ def define_node( ts.DType.FP32, ts.DType.FP16, ], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 7448898bddc..2711d18a98b 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,10 +43,10 @@ def define_node( supported_output_dtypes.append(ts.DType.INT32) validate_valid_dtype( - self.target, inputs, supported_input_dtypes, output.tosa_spec + self.target, inputs, supported_input_dtypes, self.tosa_spec ) validate_valid_dtype( - self.target, output, supported_output_dtypes, output.tosa_spec + self.target, output, supported_output_dtypes, self.tosa_spec ) # The name of the table constant is a bit complex. diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index c5aa66a85fd..ada2e9820ae 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -53,8 +53,9 @@ def define_node( ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, + ts.DType.BF16, ], - output.tosa_spec, + self.tosa_spec, ) output_rank = len(output.shape) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index a32cb3aac06..7983afbe786 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -50,7 +50,7 @@ def define_node( ts.DType.FP32, ts.DType.BOOL, ], - output.tosa_spec, + self.tosa_spec, ) tosa_graph = cast(ts.TosaSerializer, tosa_graph) diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index f0b6538ac27..7e183e2bd1e 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,24 +43,24 @@ def define_node( ) -> None: supported_dtypes = [ts.DType.BOOL] - if output.tosa_spec.support_integer(): + if self.tosa_spec.support_integer(): supported_dtypes += [ ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ] - if output.tosa_spec.support_float(): + if self.tosa_spec.support_float(): supported_dtypes += [ts.DType.FP16, ts.DType.FP32] validate_num_inputs(self.target, inputs, 3) # Not first input, which is condition tensor. validate_same_dtype(self.target, inputs[1:], ts) - validate_valid_dtype(self.target, inputs[0], ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, inputs[0], ts.DType.BOOL, self.tosa_spec) validate_valid_dtype( self.target, [*inputs[1:], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py index b4ac4f4f6f1..6c2bfcac6c9 100644 --- a/backends/arm/operators/op_while.py +++ b/backends/arm/operators/op_while.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -76,9 +76,7 @@ def define_node( ) in zip(outputs_needing_tensors, output_dim_orders, strict=True): tensor_name = output_needing_tensor.name + "_dummy" shape = output_needing_tensor.meta["val"].shape - dtype = map_dtype( - output_needing_tensor.meta["val"].dtype, self.tosa_spec - ) + dtype = map_dtype(output_needing_tensor.meta["val"].dtype) tosa_graph.currRegion.currBasicBlock.addTensor( tensor_name, diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 20ee10534d0..6b3271ee8e4 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -127,7 +127,7 @@ def validate_valid_dtype( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) """ diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 3e8cda76b5a..15f07db2e1b 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -52,7 +52,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) if self.target in [ "aten.logical_and.default", @@ -63,7 +63,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.BOOL], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() attr_builder(attr) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 0930d7e7997..21c524f2644 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -49,7 +49,7 @@ def define_node( ts.DType.INT16, ts.DType.INT32, ] - if output.tosa_spec.support_float(): + if self.tosa_spec.support_float(): supported_dtypes += [ts.DType.FP32] if self.tosa_spec.support_extension("int16"): supported_dtypes += [ts.DType.INT48] @@ -59,7 +59,7 @@ def define_node( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) # Simply add an identityOp diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index b85b1b43013..ac05dbcdc04 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -11,6 +11,7 @@ import torch import torch.fx import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification @@ -26,6 +27,23 @@ from torch.export.exported_program import ExportedProgram +def _tensor_to_numpy_with_dim_order( + tensor: torch.Tensor, dim_order: tuple[int, ...] +) -> np.ndarray: + tensor = tensor.detach().cpu().contiguous() + if tensor.dtype == torch.bfloat16: + try: + import ml_dtypes # type: ignore[import-not-found] + except ImportError as e: + raise RuntimeError( + "ml_dtypes is required to serialize bfloat16 tensors for TOSA. Have you run setup.sh?" + ) from e + np_tensor = tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16) + else: + np_tensor = tensor.numpy() + return np.transpose(np_tensor, dim_order) + + def process_call_function( node: torch.fx.Node, tosa_graph: Any, @@ -114,9 +132,9 @@ def process_inputs_to_parameters( f"Expected parameter '{node.name}' to be a torch.Tensor, got " f"{type(parameter_data).__name__}" ) - parameter_values = parameter_data.detach().numpy() - - parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) + parameter_values = _tensor_to_numpy_with_dim_order( + parameter_data, tosa_arg.dim_order # type: ignore[arg-type] + ) tosa_graph.addConst( parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name @@ -144,12 +162,7 @@ def process_inputs_to_buffers( f"Expected buffer '{node.name}' to be a torch.Tensor, got " f"{type(buffer_data).__name__}" ) - buffer_values = buffer_data.detach().numpy() - - # TODO: fragile code for temporary fix - # the mean and var tensors are also stored here but they have shape (1, ) - # we only transpose weights here - buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) + buffer_values = _tensor_to_numpy_with_dim_order(buffer_data, tosa_arg.dim_order) # type: ignore[arg-type] tosa_graph.addConst( buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name @@ -170,8 +183,10 @@ def process_inputs_to_lifted_tensor_constants( "Is the original torch function supported?" ) from e tensor = get_lifted_tensor_constant(edge_program, node) - tensor_data = tensor.detach().numpy() # type: ignore[union-attr] - tensor_values = np.transpose(tensor_data, tosa_arg.dim_order) + tensor_values = _tensor_to_numpy_with_dim_order( + tensor, # type: ignore[arg-type] + tosa_arg.dim_order, # type: ignore[arg-type] + ) tosa_graph.addConst( tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name diff --git a/backends/arm/test/ops/test_abs.py b/backends/arm/test/ops/test_abs.py index 9e8ad2e3d03..632ae24d5cd 100644 --- a/backends/arm/test/ops/test_abs.py +++ b/backends/arm/test/ops/test_abs.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -34,14 +34,20 @@ class Abs(torch.nn.Module): "randn_4d": lambda: (torch.randn(1, 2, 3, 4),), "torch_normal": lambda: (torch.normal(mean=0, std=10, size=(2, 3, 4)),), } + test_parameters_bf16 = { + "randn_1d_bf16": lambda: (torch.randn(8, dtype=torch.bfloat16),), + "randn_4d_bf16": lambda: (torch.randn(1, 2, 3, 4, dtype=torch.bfloat16),), + } def forward(self, x): return torch.abs(x) -@common.parametrize("test_data", Abs.test_parameters) +@common.parametrize("test_data", Abs.test_parameters | Abs.test_parameters_bf16) def test_abs_tosa_FP(test_data: torch.Tensor): - pipeline = TosaPipelineFP[input_t1](Abs(), test_data(), aten_op, exir_op) + pipeline = TosaPipelineFP[input_t1]( + Abs(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"] + ) pipeline.run() diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 31c31c3e88a..4ad1f2b1ba0 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -49,7 +49,7 @@ class Add2(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): return x + y - test_data: list[input_t2] = { + test_data = { "5d_float": lambda: ( torch.FloatTensor([1, 2, 3, 5, 7]), (torch.FloatTensor([2, 1, 2, 1, 10])), @@ -70,6 +70,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): torch.randn(1, 10, 20, 30), ), } + test_data_bf16 = { + "4d_big_small_bf16": lambda: ( + (10e10) * torch.randn(1, 10, 20, 30, dtype=torch.bfloat16), + torch.randn(1, 10, 20, 30, dtype=torch.bfloat16), + ), + } class Add3(torch.nn.Module): @@ -89,6 +95,15 @@ def test_add_tensor_tosa_FP(test_data: input_t1): pipeline.run() +@common.parametrize("test_data", Add.test_data) +def test_add_tensor_tosa_FP_bf16(test_data: input_t1): + x = test_data()[0].to(torch.bfloat16) + pipeline = TosaPipelineFP[input_t1]( + Add(), (x,), aten_op, exir_op, tosa_extensions=["bf16"] + ) + pipeline.run() + + @common.parametrize("test_data", Add.test_data) def test_add_tensor_tosa_INT(test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0) @@ -152,9 +167,11 @@ def test_add_tensor_u85_INT(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", Add2.test_data) +@common.parametrize("test_data", Add2.test_data | Add2.test_data_bf16) def test_add_tensor_tosa_FP_2(test_data: input_t2): - pipeline = TosaPipelineFP[input_t2](Add2(), test_data(), aten_op, exir_op) + pipeline = TosaPipelineFP[input_t2]( + Add2(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"] + ) pipeline.run() diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index 8885d19ddb4..a0791a88d9c 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -124,8 +124,19 @@ def forward(self, x: torch.Tensor): "becomes_mean_rank5": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)), } +test_modules_bf16 = { + "rand_bf16": lambda: ( + AvgPool2d(4, 2, 0, False), + (torch.rand(1, 16, 50, 32, dtype=torch.bfloat16),), + ), + "kernel_3x3_stride_1_pad_1_bf16": lambda: ( + AvgPool2d((3, 3), (1, 1), 1), + (torch.rand(1, 4, 12, 12, dtype=torch.bfloat16),), + ), +} -@common.parametrize("test_module", test_modules) + +@common.parametrize("test_module", test_modules | test_modules_bf16) def test_avg_pool2d_tosa_FP(test_module): model, input_tensor = test_module() @@ -134,6 +145,7 @@ def test_avg_pool2d_tosa_FP(test_module): input_tensor, aten_op, exir_op, + tosa_extensions=["bf16"], run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"), ) pipeline.run() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index a037d0e366f..8a2f990c277 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -66,6 +66,23 @@ class Cat(torch.nn.Module): ), } + test_parameters_bf16 = { + "cat_rand_two_tensors_bf16": lambda: ( + ( + torch.randn(1, 2, 4, 4, dtype=torch.bfloat16), + torch.randn(1, 2, 4, 1, dtype=torch.bfloat16), + ), + 3, + ), + "cat_rand_dim0_bf16": lambda: ( + ( + torch.randn(1, 2, 4, 4, dtype=torch.bfloat16), + torch.randn(1, 2, 4, 4, dtype=torch.bfloat16), + ), + 0, + ), + } + def __init__(self): super().__init__() @@ -73,19 +90,20 @@ def forward(self, t: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor: return torch.cat(t, dim=dim) -@common.parametrize("test_data", Cat.test_parameters) +@common.parametrize("test_data", Cat.test_parameters | Cat.test_parameters_bf16) def test_cat_tosa_FP(test_data: Tuple): pipeline = TosaPipelineFP[input_t1]( Cat(), test_data(), aten_op, exir_op, + tosa_extensions=["bf16"], ) pipeline.run() def test_cat_tosa_FP_4d(): - square = torch.ones((2, 2, 2, 2)) + square = torch.ones((2, 2, 2, 2), dtype=torch.bfloat16) for dim in range(-3, 3): test_data = ((square, square.clone()), dim) pipeline = TosaPipelineFP[input_t1]( @@ -93,6 +111,7 @@ def test_cat_tosa_FP_4d(): test_data, aten_op, exir_op, + tosa_extensions=["bf16"], ) pipeline.run() diff --git a/backends/arm/test/ops/test_ceil.py b/backends/arm/test/ops/test_ceil.py index 93b5f9cd009..7f03d3ce11f 100644 --- a/backends/arm/test/ops/test_ceil.py +++ b/backends/arm/test/ops/test_ceil.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,8 +42,19 @@ def forward(self, x: torch.Tensor): "ceil_ramp": lambda: (Ceil(), ramp), } +test_data_bf16 = { + "ceil_rand_bf16": lambda: ( + Ceil(), + (torch.rand(4, 4, dtype=torch.bfloat16) - 0.5), + ), + "ceil_ramp_bf16": lambda: ( + Ceil(), + torch.arange(-8, 8, 0.25, dtype=torch.bfloat16), + ), +} -@common.parametrize("test_data", test_data) + +@common.parametrize("test_data", test_data | test_data_bf16) def test_ceil_tosa_FP(test_data: input_t1): module, data = test_data() pipeline = TosaPipelineFP[input_t1]( @@ -51,6 +62,7 @@ def test_ceil_tosa_FP(test_data: input_t1): (data,), module.aten_op, module.exir_op, + tosa_extensions=["bf16"], ) pipeline.run() diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py index 13c3479daa3..ad01be27cf9 100644 --- a/backends/arm/test/ops/test_clamp.py +++ b/backends/arm/test/ops/test_clamp.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -35,6 +35,19 @@ "rank_4_no_max": lambda: (torch.rand(1, 10, 10, 1) - 3, -3.3, None), } +test_data_suite_bf16 = { + "rank_2_bf16": lambda: ( + torch.rand(1, 35, dtype=torch.bfloat16), + 0.5, + 0.8, + ), + "rank_4_no_max_bf16": lambda: ( + torch.rand(1, 10, 10, 1, dtype=torch.bfloat16) - 3, + -3.3, + None, + ), +} + test_data_suite_int32 = { "int32_rank2": lambda: (torch.randint(-50, 50, (2, 3), dtype=torch.int32), -10, 10), "int32_rank3_no_min": lambda: ( @@ -70,7 +83,7 @@ def forward(self, x): return torch.clamp(x, self.clamp_min, self.clamp_max) -@common.parametrize("test_data", test_data_suite) +@common.parametrize("test_data", test_data_suite | test_data_suite_bf16) def test_clamp_tosa_FP(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -80,6 +93,7 @@ def test_clamp_tosa_FP(test_data): (input_tensor,), aten_op, exir_op, + tosa_extensions=["bf16"], ) pipeline.run() @@ -254,6 +268,19 @@ def test_clamp_vgf_quant(test_data): ), } +test_data_suite_tensor_bf16 = { + "rank_2_bf16": lambda: ( + torch.rand(1, 35, dtype=torch.bfloat16), + torch.tensor(0.5, dtype=torch.bfloat16), + torch.tensor(0.8, dtype=torch.bfloat16), + ), + "rank_4_no_max_bf16": lambda: ( + torch.rand(10, 20, 30, 40, dtype=torch.bfloat16) - 3, + torch.tensor(-0.1, dtype=torch.bfloat16), + None, + ), +} + test_data_suite_tensor_INT32 = { "int32_rank2": lambda: ( torch.randint(-50, 50, (2, 3), dtype=torch.int32), @@ -328,7 +355,9 @@ def test_clamp_vgf_quant(test_data): } -@common.parametrize("test_data", test_data_suite_tensor_FP) +@common.parametrize( + "test_data", test_data_suite_tensor_FP | test_data_suite_tensor_bf16 +) def test_clamp_tosa_FP_tensor(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -338,6 +367,7 @@ def test_clamp_tosa_FP_tensor(test_data): (input_tensor,), aten_op_tensor, exir_op_tensor, + tosa_extensions=["bf16"], ) pipeline.run() diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 44e9e75d2d6..9aac49f9cb1 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -60,7 +60,7 @@ torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, - torch.bfloat16: np.float32, + torch.bfloat16: np.uint16, torch.complex32: np.complex64, torch.complex64: np.complex64, torch.complex128: np.complex128, @@ -172,16 +172,17 @@ def get_output_quantization_params( def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: - dtype = _torch_to_numpy_dtype_dict[tensor.dtype] - array = tensor.detach().numpy().astype(dtype) # type: ignore[var-annotated] dim_order = tensor.dim_order() if dim_order == NHWC_ORDER: - a = array.transpose(NHWC_ORDER) - return a + tensor = tensor.permute(NHWC_ORDER) elif dim_order == NNHWC_ORDER: - return array.transpose(NNHWC_ORDER) - else: - return array + tensor = tensor.permute(NNHWC_ORDER) + + tensor = tensor.detach() + if tensor.dtype == torch.bfloat16: + # Numpy doesn't support bfloat16, use, uint16 instead. Dtype is inferred from model anyways. + tensor = tensor.view(torch.uint16) + return tensor.numpy() def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: @@ -197,8 +198,12 @@ def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: tensor = torch.from_numpy(array).reshape(shape_with_dim_order) return tensor.permute(NNHWC_INVERSE_ORDER).to(memory_format=torch.channels_last) else: - tensor = torch.from_numpy(array).reshape(shape) - return tensor + if type(array.dtype) is np.dtypes.VoidDType: + # If dtype is void, "cheat" and use the output_tensor dtype. + tensor = torch.frombuffer(array, dtype=output_tensor.dtype) + else: + tensor = torch.from_numpy(array) + return tensor.reshape(shape) class TosaReferenceModelDispatch(TorchFunctionMode): diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 66f90ead413..e12902b6e2b 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -1025,7 +1025,7 @@ def _get_dtype_distribution( placeholder_dtypes.append(str(node.meta["val"].dtype)) if node.op == "call_function": if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): - dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec) + dtype, _, _ = extract_tensor_meta(node.meta) call_function_dtypes.append(ts.DTypeNames[dtype]) return Counter(placeholder_dtypes), Counter(call_function_dtypes) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index c11a046cd66..494c0120f8b 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -11,7 +11,7 @@ import operator from enum import Enum -from typing import Any, Optional, Sequence +from typing import Any, Sequence import torch import tosa_serializer as ts @@ -76,12 +76,11 @@ def min(self): raise ValueError(f"Unrecognized TosaSpecialDtype {self}.") -def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: +def map_dtype(data_type: torch.dtype) -> Any: """Map a ``torch.dtype`` to a ``ts.DType``. Args: data_type (torch.dtype): PyTorch dtype to convert. - tosa_spec (TosaSpecification): Active spec (reserved for future checks). Returns: ts.DType: Matching serializer dtype. @@ -114,12 +113,11 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: # Returns the shape and type of a node # TODO: other types, can be # SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None -def extract_tensor_meta(meta, tosa_spec: TosaSpecification): +def extract_tensor_meta(meta): """Extract dtype, shape, and dimension order from FX metadata. Args: meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple). - tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping. Returns: tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing @@ -140,7 +138,7 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification): raise ValueError( f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}" ) - dtype = map_dtype(val.dtype, tosa_spec) + dtype = map_dtype(val.dtype) shape = tuple(val.size()) if meta.get("tosa_dim_order") is not None: @@ -165,7 +163,6 @@ class TosaArg: ``range(len(shape))``. special (list | None): Captured list when the argument is a sequence. number (float | int | None): Captured numeric value when provided. - tosa_spec (TosaSpecification): Active specification used for mapping. multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise. """ @@ -181,9 +178,8 @@ def __process_node(self, argument: torch.fx.Node): if "val" in argument.meta: output_dtype, self.shape, self.dim_order = extract_tensor_meta( - argument.meta, self.tosa_spec - ) - # Handle special case of types not representable in torch (i.e. i48_t) + argument.meta + ) # Handle special case of types not representable in torch (i.e. i48_t) if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): output_dtype = special_type.get_tosa_dtype() @@ -199,11 +195,6 @@ def __process_node(self, argument: torch.fx.Node): else: self.multiple_output_names = [] - if not self.__validate(): - raise ValueError( - f"{self.tosa_spec} doesn't support tensor {self.__repr__()}" - ) - def __process_list(self, argument): """Capture a sequence argument as ``special``. @@ -222,20 +213,21 @@ def __process_number(self, argument: float | int): """ self.number: float | int = argument - def __validate(self) -> bool: + def __validate(self, tosa_spec: TosaSpecification) -> bool: match getattr(self, "dtype", None): case ts.DType.FP32: - if not self.tosa_spec.support_float(): + if not tosa_spec.support_float(): return False case ts.DType.INT4: - if not self.tosa_spec.support_extension("int4"): + if not tosa_spec.support_extension("int4"): + return False + case ts.DType.BF16: + if not tosa_spec.support_extension("bf16"): return False return True - def __init__( - self, argument: Any, tosa_spec: Optional[TosaSpecification] = None - ) -> None: + def __init__(self, argument: Any, tosa_spec: TosaSpecification) -> None: """Initialize the argument wrapper and populate fields. Args: @@ -245,20 +237,17 @@ def __init__( required for metadata extraction. Raises: - ValueError: If ``tosa_spec`` is missing or has the wrong type. RuntimeError: If ``argument`` is of an unsupported type. """ - if tosa_spec is None: - raise ValueError("tosa_spec is None") - elif not isinstance(tosa_spec, TosaSpecification): - raise ValueError( - f"Expected tosa_spec to be a TosaSpecification, but got {tosa_spec}" - ) - self.tosa_spec = tosa_spec if isinstance(argument, torch.fx.Node): self.__process_node(argument) + if not self.__validate(tosa_spec): + raise ValueError( + f"{tosa_spec} doesn't support tensor {self.__repr__()}" + ) + return if isinstance(argument, Sequence): self.__process_list(argument) @@ -302,8 +291,7 @@ def __repr__(self): attrs.append(f"special={self.special!r}") if hasattr(self, "number") and self.number is not None: attrs.append(f"number={self.number!r}") - if hasattr(self, "tosa_spec") and self.tosa_spec is not None: - attrs.append(f"tosa_spec={self.tosa_spec!r}") if hasattr(self, "multiple_output_names"): - attrs.append(f"names={self.multiple_output_names!r}") + if len(self.multiple_output_names) > 0: + attrs.append(f"names={self.multiple_output_names!r}") return f"{self.__class__.__name__}({', '.join(attrs)})" diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index df77153e29f..6666e422039 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,7 +15,6 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.mapping import extract_tensor_meta -from executorch.backends.arm.tosa.specification import TosaSpecification from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Node @@ -64,9 +63,7 @@ def are_fake_tensors_broadcastable( return (True, list(reversed(broadcast_shape))) -def broadcast_tensors( - tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification -) -> list[Any]: +def broadcast_tensors(tosa_fb, nodes: list[Node]) -> list[Any]: """Broadcast the FX nodes to a shared shape inside the TOSA graph. This mirrors ``reshape_for_broadcast`` but also emits the tile operators @@ -96,7 +93,7 @@ def broadcast_tensors( broadcast_tensors = [] for node in nodes: - tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta, tosa_spec) + tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta) list_tens_shape = list(tens_shape) # Already in the right shape we can just add it to the list. if list_tens_shape == common_shape: