Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/tracksdata/array/_base_array.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import abc
from typing import Any

import numpy as np
from numpy.typing import ArrayLike

ArrayIndex = ArrayLike | int | slice | tuple[ArrayLike | int | slice, ...]


class BaseReadOnlyArray(abc.ABC):
class BaseReadOnlyArray(np.lib.mixins.NDArrayOperatorsMixin, abc.ABC):
"""
Base class for read-only array-like objects.

Arithmetic and comparison operators (e.g. `array_view == 0`,
`array_view + np.ones(...)`) materialize the array content and
delegate to the corresponding numpy ufunc via `__array_ufunc__`.
"""

# NDArrayOperatorsMixin defines `__eq__`, which would otherwise reset
# `__hash__` to None; keep the default identity hash.
__hash__ = object.__hash__

def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any) -> Any:
if kwargs.get("out") is not None:
raise TypeError(f"`out` is not supported for read-only {type(self).__name__}.")
inputs = tuple(np.asarray(x) if isinstance(x, BaseReadOnlyArray) else x for x in inputs)
return getattr(ufunc, method)(*inputs, **kwargs)

def __len__(self) -> int:
"""Returns the length of the first dimension of the array."""
return self.shape[0]
Expand Down
34 changes: 34 additions & 0 deletions src/tracksdata/array/_test/test_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,40 @@ def test_graph_array_view_equal(multi_node_graph_from_image) -> None:
# assert np.array_equal(array_view[0], label[0])


def test_graph_array_view_numpy_operators(multi_node_graph_from_image) -> None:
"""Operators between GraphArrayView and numpy arrays/scalars must cast to the array content."""
array_view, label = multi_node_graph_from_image
t = 1
frame_shape = label.shape[1:]

# comparison with a numpy array, in both operand orders
np.testing.assert_array_equal(array_view[t] == np.zeros(frame_shape), label[t] == 0)
np.testing.assert_array_equal(np.zeros(frame_shape) == array_view[t], label[t] == 0)

# comparison with a scalar must be elementwise, not object identity
result = array_view[t] == t + 1
assert isinstance(result, np.ndarray)
np.testing.assert_array_equal(result, label[t] == t + 1)
np.testing.assert_array_equal(array_view[t] != 0, label[t] != 0)
np.testing.assert_array_equal(array_view[t] > 0, label[t] > 0)

# arithmetic with arrays and scalars
np.testing.assert_array_equal(array_view[t] + np.ones(frame_shape), label[t] + 1.0)
np.testing.assert_array_equal(array_view[t] + 1, label[t] + 1)
np.testing.assert_array_equal(array_view[t] * 2, label[t] * 2)
np.testing.assert_array_equal(-array_view[t], -label[t].astype(np.int64))

# between two views
np.testing.assert_array_equal(array_view[t] == array_view[t], np.ones(frame_shape, dtype=bool))

# `out=` is not supported for read-only arrays
with pytest.raises(TypeError, match="out"):
np.add(array_view[t], 1, out=np.zeros(frame_shape))

# operator support must not break hashability (e.g. signal connections)
assert isinstance(hash(array_view), int)


def test_graph_array_view_getitem_multi_slices(multi_node_graph_from_image) -> None:
"""Test __getitem__ with slices."""
array_view, label = multi_node_graph_from_image
Expand Down
Loading