From b5b429915b445715ff8bbfff876275867596a9cb Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 1 Jul 2026 14:41:22 -0700 Subject: [PATCH] fix: support numpy operators on GraphArrayView Operations between a GraphArrayView and numpy arrays or scalars did not cast the view to its array content: 'view[t] == 0' silently returned Python's identity False, and 'view[t] + 1' / 'view[t] > 0' raised TypeError. Array operands only worked through numpy's reflected ops. Make BaseReadOnlyArray inherit NDArrayOperatorsMixin and implement __array_ufunc__, materializing read-only array operands before delegating to the ufunc. 'out=' is rejected (read-only) and identity hashing is kept (the mixin's __eq__ would otherwise disable it, breaking signal connections). --- src/tracksdata/array/_base_array.py | 17 +++++++++- .../array/_test/test_graph_array.py | 34 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/tracksdata/array/_base_array.py b/src/tracksdata/array/_base_array.py index 07a94d59..6898804c 100644 --- a/src/tracksdata/array/_base_array.py +++ b/src/tracksdata/array/_base_array.py @@ -1,4 +1,5 @@ import abc +from typing import Any import numpy as np from numpy.typing import ArrayLike @@ -6,11 +7,25 @@ ArrayIndex = ArrayLike | int | slice | tuple[ArrayLike | int | slice, ...] -class BaseReadOnlyArray(abc.ABC): +class BaseReadOnlyArray(np.lib.mixins.NDArrayOperatorsMixin, abc.ABC): """ Base class for read-only array-like objects. + + Arithmetic and comparison operators (e.g. `array_view == 0`, + `array_view + np.ones(...)`) materialize the array content and + delegate to the corresponding numpy ufunc via `__array_ufunc__`. """ + # NDArrayOperatorsMixin defines `__eq__`, which would otherwise reset + # `__hash__` to None; keep the default identity hash. + __hash__ = object.__hash__ + + def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any) -> Any: + if kwargs.get("out") is not None: + raise TypeError(f"`out` is not supported for read-only {type(self).__name__}.") + inputs = tuple(np.asarray(x) if isinstance(x, BaseReadOnlyArray) else x for x in inputs) + return getattr(ufunc, method)(*inputs, **kwargs) + def __len__(self) -> int: """Returns the length of the first dimension of the array.""" return self.shape[0] diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index fd839bbf..8543cb33 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -285,6 +285,40 @@ def test_graph_array_view_equal(multi_node_graph_from_image) -> None: # assert np.array_equal(array_view[0], label[0]) +def test_graph_array_view_numpy_operators(multi_node_graph_from_image) -> None: + """Operators between GraphArrayView and numpy arrays/scalars must cast to the array content.""" + array_view, label = multi_node_graph_from_image + t = 1 + frame_shape = label.shape[1:] + + # comparison with a numpy array, in both operand orders + np.testing.assert_array_equal(array_view[t] == np.zeros(frame_shape), label[t] == 0) + np.testing.assert_array_equal(np.zeros(frame_shape) == array_view[t], label[t] == 0) + + # comparison with a scalar must be elementwise, not object identity + result = array_view[t] == t + 1 + assert isinstance(result, np.ndarray) + np.testing.assert_array_equal(result, label[t] == t + 1) + np.testing.assert_array_equal(array_view[t] != 0, label[t] != 0) + np.testing.assert_array_equal(array_view[t] > 0, label[t] > 0) + + # arithmetic with arrays and scalars + np.testing.assert_array_equal(array_view[t] + np.ones(frame_shape), label[t] + 1.0) + np.testing.assert_array_equal(array_view[t] + 1, label[t] + 1) + np.testing.assert_array_equal(array_view[t] * 2, label[t] * 2) + np.testing.assert_array_equal(-array_view[t], -label[t].astype(np.int64)) + + # between two views + np.testing.assert_array_equal(array_view[t] == array_view[t], np.ones(frame_shape, dtype=bool)) + + # `out=` is not supported for read-only arrays + with pytest.raises(TypeError, match="out"): + np.add(array_view[t], 1, out=np.zeros(frame_shape)) + + # operator support must not break hashability (e.g. signal connections) + assert isinstance(hash(array_view), int) + + def test_graph_array_view_getitem_multi_slices(multi_node_graph_from_image) -> None: """Test __getitem__ with slices.""" array_view, label = multi_node_graph_from_image