diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 72cdec6d..541e741b 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from copy import copy -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np @@ -11,9 +11,6 @@ from tracksdata.options import get_options from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype -if TYPE_CHECKING: - from tracksdata.nodes._mask import Mask - def _validate_shape( shape: tuple[int, ...] | None, @@ -346,14 +343,17 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda np.ndarray The filled buffer. """ + # Local import: avoids the graph <-> nodes package import cycle (importing + # tracksdata.nodes re-enters the partially-initialized graph package). + from tracksdata.nodes._mask import as_mask + subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)] df = subgraph.node_attrs( attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK], ) for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): - mask: Mask - mask.paint_buffer(buffer, value, offset=self._offset) + as_mask(mask).paint_buffer(buffer, value, offset=self._offset) def _offset_as_array(self, ndim: int) -> np.ndarray: """Normalize `offset` to a vector for each spatial axis.""" diff --git a/src/tracksdata/edges/_iou_edges.py b/src/tracksdata/edges/_iou_edges.py index c70a3492..ce591b13 100644 --- a/src/tracksdata/edges/_iou_edges.py +++ b/src/tracksdata/edges/_iou_edges.py @@ -1,6 +1,11 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs -from tracksdata.nodes._mask import Mask +from tracksdata.nodes._mask import Mask, as_mask + + +def _mask_iou(source_mask: "Mask | dict", target_mask: "Mask | dict") -> float: + """IoU between two mask attribute values (struct dicts or `Mask` instances).""" + return as_mask(source_mask).iou(as_mask(target_mask)) class IoUEdgeAttr(GenericFuncEdgeAttrs): @@ -22,7 +27,7 @@ def __init__( mask_key: str = DEFAULT_ATTR_KEYS.MASK, ): super().__init__( - func=Mask.iou, + func=_mask_iou, attr_keys=mask_key, output_key=output_key, ) diff --git a/src/tracksdata/edges/_test/test_iou_edges.py b/src/tracksdata/edges/_test/test_iou_edges.py index c9ecdf6d..7d5474b2 100644 --- a/src/tracksdata/edges/_test/test_iou_edges.py +++ b/src/tracksdata/edges/_test/test_iou_edges.py @@ -4,6 +4,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges import IoUEdgeAttr +from tracksdata.edges._iou_edges import _mask_iou from tracksdata.graph import RustWorkXGraph from tracksdata.nodes import Mask from tracksdata.options import get_options, options_context @@ -15,7 +16,7 @@ def test_iou_edges_init_default() -> None: assert operator.output_key == "iou_score" assert operator.attr_keys == DEFAULT_ATTR_KEYS.MASK - assert operator.func == Mask.iou + assert operator.func == _mask_iou def test_iou_edges_init_custom() -> None: @@ -24,7 +25,7 @@ def test_iou_edges_init_custom() -> None: assert operator.output_key == "custom_iou" assert operator.attr_keys == "custom_mask" - assert operator.func == Mask.iou + assert operator.func == _mask_iou @pytest.mark.parametrize("n_workers", [1, 2]) diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index 712cf53d..cdc996c5 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -5,7 +5,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.functional import to_napari_format from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import MaskDiskAttrs +from tracksdata.nodes import MaskDiskAttrs, as_mask @pytest.mark.parametrize("metadata_shape", [True, False]) @@ -47,7 +47,7 @@ def test_napari_conversion(metadata_shape: bool) -> None: graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 6)) masks = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK] graph.update_node_attrs( - attrs={DEFAULT_ATTR_KEYS.BBOX: [mask.bbox for mask in masks]}, + attrs={DEFAULT_ATTR_KEYS.BBOX: [as_mask(mask).bbox for mask in masks]}, node_ids=graph.node_ids(), ) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 44e2d1d0..34d55b00 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1378,11 +1378,14 @@ def compute_overlaps(self, iou_threshold: float = 0.0) -> None: raise ValueError("iou_threshold must be between 0.0 and 1.0") def _estimate_overlaps(t: int) -> list[list[int, 2]]: + # Local import: avoids the graph <-> nodes package import cycle. + from tracksdata.nodes._mask import as_mask + node_attrs = self.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t).node_attrs( attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MASK], ) node_ids = node_attrs[DEFAULT_ATTR_KEYS.NODE_ID].to_list() - masks = node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list() + masks = [as_mask(m) for m in node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list()] overlaps = [] for i in range(len(masks)): mask_i = masks[i] @@ -1939,8 +1942,11 @@ def to_geff( } if DEFAULT_ATTR_KEYS.MASK in node_attrs.columns: + # Local import: avoids the graph <-> nodes package import cycle. + from tracksdata.nodes._mask import as_mask + node_dict[DEFAULT_ATTR_KEYS.MASK] = construct_var_len_props( - [mask.mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]] + [as_mask(mask).mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]] ) edge_dict = {k: {"values": column_to_numpy(v), "missing": None} for k, v in edge_attrs.to_dict().items()} diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 2c703c15..74293a46 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -23,7 +23,7 @@ from tracksdata.graph._base_graph import BaseGraph from tracksdata.graph.filters._base_filter import BaseFilter from tracksdata.utils._cache import cache_method -from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns +from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_columns from tracksdata.utils._dtypes import ( STRUCT_FIELD_SEP, AttrSchema, @@ -422,7 +422,7 @@ def _read_attr_dataframe(self, query: sa.Select, table: type[DeclarativeBase]) - schema_overrides=self._graph._polars_schema_override(table), ) - df = unpickle_bytes_columns(df) + df = unpickle_columns(df, self._graph._pickled_physical_columns(table)) return self._graph._cast_columns(table, df) def _query_from_attr_keys( @@ -656,6 +656,11 @@ def __init__( self._engine_kwargs = engine_kwargs if engine_kwargs is not None else {} self._engine = sa.create_engine(self._url, **self._engine_kwargs) + # Initialized before `_define_schema`, which (when reloading an existing + # database) reads the attribute schemas to restore pickle column types. + self._node_attr_schemas_cache: dict | None = None + self._edge_attr_schemas_cache: dict | None = None + # Create unique classes for this instance self._define_schema(overwrite=overwrite) @@ -666,8 +671,6 @@ def __init__( self._max_id_per_time = {} self._update_max_id_per_time() - self._node_attr_schemas_cache: dict | None = None - self._edge_attr_schemas_cache: dict | None = None def supports_custom_indices(self) -> bool: return True @@ -686,8 +689,6 @@ class Base(DeclarativeBase): pass if len(metadata.tables) > 0 and not overwrite: - for table in metadata.tables.values(): - self._restore_pickled_column_types(table) for table_name, table in metadata.tables.items(): cls = type( table_name, @@ -699,6 +700,10 @@ class Base(DeclarativeBase): ) setattr(self, table_name, cls) self.Base = Base + # Restore pickle column types only after the ORM classes exist, so + # `_pickled_physical_columns` can read the stored attribute schemas. + for table_name in metadata.tables: + self._restore_pickled_column_types(getattr(self, table_name)) return class Node(Base): @@ -794,7 +799,37 @@ def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[st @staticmethod def _is_pickled_sql_type(column_type: TypeEngine) -> bool: - return isinstance(column_type, sa.PickleType | sa.LargeBinary) + """Whether a SQL column stores pickled Python objects. + + Only ``PickleType`` columns are pickled. Plain ``LargeBinary`` columns + hold raw bytes (e.g. the blosc2-compressed Mask ``data`` leaf) and must + not be unpickled. After reflection both report as ``LargeBinary``, so + :meth:`_restore_pickled_column_types` re-tags the genuinely-pickled ones + as ``PickleType`` before this check is used. + """ + return isinstance(column_type, sa.PickleType) + + def _pickled_physical_columns(self, table_class: type[DeclarativeBase]) -> list[str]: + """Physical column names whose values are pickled (vs stored natively).""" + return [col.name for col in table_class.__table__.columns if self._is_pickled_sql_type(col.type)] + + def _raw_binary_physical_columns(self, table_class: type[DeclarativeBase]) -> set[str]: + """Physical column names that hold raw ``pl.Binary`` bytes (never pickled). + + These are the only blob columns left as ``LargeBinary`` after reflection; + every other blob column is re-tagged as ``PickleType``. + """ + raw: set[str] = set() + for key, schema in self._attr_schemas_for_table(table_class).items(): + if isinstance(schema.dtype, pl.Struct): + raw.update( + flat_col + for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype) + if leaf_dtype == pl.Binary + ) + elif schema.dtype == pl.Binary: + raw.add(key) + return raw @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: @@ -840,9 +875,21 @@ def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas self._edge_attr_schemas_cache = None - def _restore_pickled_column_types(self, table: sa.Table) -> None: - for column in table.columns: - if isinstance(column.type, sa.LargeBinary): + def _restore_pickled_column_types(self, table_class: type[DeclarativeBase]) -> None: + """Restore ``PickleType`` on reflected pickle columns. + + Reflection reports every blob column as ``LargeBinary``, losing the + distinction between genuinely-pickled columns and raw-binary ones. We + consult the stored schema (via :meth:`_raw_binary_physical_columns`) and + re-tag every blob column as ``PickleType`` except the raw-binary ones + (e.g. the Mask ``data`` leaf), which stay ``LargeBinary`` so writes store + their bytes directly. + """ + if table_class.__tablename__ not in (self.Node.__tablename__, self.Edge.__tablename__): + return + raw_binary = self._raw_binary_physical_columns(table_class) + for column in table_class.__table__.columns: + if isinstance(column.type, sa.LargeBinary) and column.name not in raw_binary: column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: @@ -1357,7 +1404,7 @@ def _get_neighbors( filter_node_ids, self.Node, ) - node_df = unpickle_bytes_columns(node_df) + node_df = unpickle_columns(node_df, self._pickled_physical_columns(self.Node)) node_df = self._cast_columns(self.Node, node_df) if single_node: @@ -1550,7 +1597,7 @@ def node_attrs( connection=session.connection(), schema_overrides=self._polars_schema_override(self.Node), ) - nodes_df = unpickle_bytes_columns(nodes_df) + nodes_df = unpickle_columns(nodes_df, self._pickled_physical_columns(self.Node)) nodes_df = self._cast_columns(self.Node, nodes_df) # Select using logical keys (struct columns are now reconstructed). @@ -1596,7 +1643,7 @@ def edge_attrs( connection=session.connection(), schema_overrides=self._polars_schema_override(self.Edge), ) - edges_df = unpickle_bytes_columns(edges_df) + edges_df = unpickle_columns(edges_df, self._pickled_physical_columns(self.Edge)) edges_df = self._cast_columns(self.Edge, edges_df) if unpack: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 93731fba..65f54cac 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -16,7 +16,8 @@ from tracksdata.graph import BaseGraph, IndexedRXGraph, RustWorkXGraph, SQLGraph from tracksdata.io._numpy_array import from_array from tracksdata.nodes import RegionPropsNodes -from tracksdata.nodes._mask import Mask +from tracksdata.nodes._mask import Mask, as_mask +from tracksdata.utils._dtypes import STRUCT_FIELD_SEP def test_already_existing_keys(graph_backend: BaseGraph) -> None: @@ -1653,10 +1654,8 @@ def build_node_map(graph: BaseGraph) -> dict[tuple[int, tuple[int, ...]], dict[s assert target_row["y"] == pytest.approx(source_row["y"]) assert target_row["x"] == pytest.approx(source_row["x"]) - source_mask = source_row[DEFAULT_ATTR_KEYS.MASK] - target_mask = target_row[DEFAULT_ATTR_KEYS.MASK] - assert isinstance(source_mask, Mask) - assert isinstance(target_mask, Mask) + source_mask = as_mask(source_row[DEFAULT_ATTR_KEYS.MASK]) + target_mask = as_mask(target_row[DEFAULT_ATTR_KEYS.MASK]) np.testing.assert_array_equal(source_mask.mask, target_mask.mask) np.testing.assert_array_equal(source_mask.bbox, target_mask.bbox) @@ -1763,6 +1762,42 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: np.testing.assert_array_equal(stored_mask.mask, mask_data) +def test_sql_graph_mask_struct_stored_raw_not_pickled(tmp_path: Path) -> None: + """Mask struct's binary `data` leaf is stored as raw BLOB, not double-pickled. + + Regression for storing the (already blosc2-compressed) mask bytes through a + SQLAlchemy ``PickleType`` column, which wrapped them in a second pickle layer. + """ + db_path = tmp_path / "mask_struct.db" + graph = SQLGraph("sqlite", str(db_path)) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, Mask.struct_dtype(2)) + + mask_data = np.array([[True, False], [False, True]], dtype=bool) + mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) + node_id = graph.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask.to_struct()}) + + data_col = f"{DEFAULT_ATTR_KEYS.MASK}{STRUCT_FIELD_SEP}{Mask.MASK_DATA_FIELD}" + # The binary leaf must be a raw LargeBinary column, never a PickleType. + assert isinstance(graph.Node.__table__.columns[data_col].type, sa.LargeBinary) + assert not isinstance(graph.Node.__table__.columns[data_col].type, sa.PickleType) + + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + # Still raw binary after reflection — not restored to PickleType. + assert isinstance(reloaded.Node.__table__.columns[data_col].type, sa.LargeBinary) + assert not isinstance(reloaded.Node.__table__.columns[data_col].type, sa.PickleType) + + df = reloaded.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK]) + assert df.schema[DEFAULT_ATTR_KEYS.MASK] == Mask.struct_dtype(2) + restored = as_mask(df[DEFAULT_ATTR_KEYS.MASK].to_list()[0]) + np.testing.assert_array_equal(restored.mask, mask_data) + np.testing.assert_array_equal(restored.bbox, mask.bbox) + + # Struct-field filtering still works against the flat physical bbox columns. + assert reloaded.filter(NodeAttr(DEFAULT_ATTR_KEYS.MASK).struct.field("min_y") == 0).node_ids() == [node_id] + + def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None: db_path = tmp_path / "struct_graph.db" graph = SQLGraph("sqlite", str(db_path)) diff --git a/src/tracksdata/metrics/_ctc_metrics.py b/src/tracksdata/metrics/_ctc_metrics.py index 5bd3e84f..2fe667ef 100644 --- a/src/tracksdata/metrics/_ctc_metrics.py +++ b/src/tracksdata/metrics/_ctc_metrics.py @@ -166,8 +166,8 @@ def _matching_data( ]: attr_keys = [DEFAULT_ATTR_KEYS.T, tracklet_id_key, *required_attrs] nodes_df = graph.node_attrs(attr_keys=attr_keys) - if use_mask_serialization: - # required by multiprocessing + if use_mask_serialization and nodes_df[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Object: + # required by multiprocessing; struct mask columns are Arrow-native and need no pickling nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK) labels = {} diff --git a/src/tracksdata/metrics/_matching.py b/src/tracksdata/metrics/_matching.py index 1dca1aed..5b1f21e1 100644 --- a/src/tracksdata/metrics/_matching.py +++ b/src/tracksdata/metrics/_matching.py @@ -144,6 +144,8 @@ def compute_weights( tuple[list[int], list[int], list[int], list[int], list[float]] Matching data: mapped_ref, mapped_comp, rows, cols, weights (IoU values). """ + # Local import: avoids the graph <-> nodes package import cycle. + from tracksdata.nodes._mask import as_mask from tracksdata.utils._dtypes import column_from_bytes # Handle serialized masks if needed @@ -151,18 +153,18 @@ def compute_weights( ref_group = column_from_bytes(ref_group, DEFAULT_ATTR_KEYS.MASK) comp_group = column_from_bytes(comp_group, DEFAULT_ATTR_KEYS.MASK) + # Materialize masks once, struct values decompress on conversion + ref_masks = [as_mask(m) for m in ref_group[DEFAULT_ATTR_KEYS.MASK]] + comp_masks = [as_mask(m) for m in comp_group[DEFAULT_ATTR_KEYS.MASK]] + mapped_ref = [] mapped_comp = [] rows = [] cols = [] weights = [] - for i, (ref_id, ref_mask) in enumerate( - zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True) - ): - for j, (comp_id, comp_mask) in enumerate( - zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True) - ): + for i, (ref_id, ref_mask) in enumerate(zip(ref_group[reference_graph_key], ref_masks, strict=True)): + for j, (comp_id, comp_mask) in enumerate(zip(comp_group[input_graph_key], comp_masks, strict=True)): # Intersection over reference is used to select the matches inter = ref_mask.intersection(comp_mask) ctc_score = inter / ref_mask.size diff --git a/src/tracksdata/nodes/__init__.py b/src/tracksdata/nodes/__init__.py index b031d4bc..1fb6f4df 100644 --- a/src/tracksdata/nodes/__init__.py +++ b/src/tracksdata/nodes/__init__.py @@ -1,8 +1,8 @@ """Node operators for creating nodes and their respective attributes (e.g. masks) in a graph.""" from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs -from tracksdata.nodes._mask import Mask, MaskDiskAttrs +from tracksdata.nodes._mask import Mask, MaskDiskAttrs, as_mask from tracksdata.nodes._random import RandomNodes from tracksdata.nodes._regionprops import RegionPropsNodes -__all__ = ["GenericFuncNodeAttrs", "Mask", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes"] +__all__ = ["GenericFuncNodeAttrs", "Mask", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes", "as_mask"] diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 22a7619f..05c465bb 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -19,6 +19,28 @@ from tracksdata.graph._base_graph import BaseGraph +def _pack_mask_array(mask: NDArray) -> bytes: + """Compress a mask array into a blosc2 cframe.""" + mask = np.ascontiguousarray(mask) + prev_nthreads = blosc2.set_nthreads(1) + # Bypass blosc2 printing overhead by directly creating a schunk and converting it to cframe, + # instead of using blosc2.pack_tensor + schunk = blosc2.SChunk(data=mask) + dtype = mask.dtype.descr if mask.dtype.kind == "V" else mask.dtype.str + schunk.vlmeta["__pack_tensor__"] = ("numpy", mask.shape, dtype) + cframe = schunk.to_cframe() + blosc2.set_nthreads(prev_nthreads) + return cframe + + +def _unpack_mask_array(data: bytes) -> NDArray: + """Decompress a blosc2 cframe into a mask array.""" + prev_nthreads = blosc2.set_nthreads(1) + mask = blosc2.unpack_tensor(data) + blosc2.set_nthreads(prev_nthreads) + return mask + + @lru_cache(maxsize=5) def _nd_sphere( radius: int, @@ -66,20 +88,11 @@ def __init__( def __getstate__(self) -> dict: data_dict = self.__dict__.copy() - prev_nthreads = blosc2.set_nthreads(1) - # Bypass blosc2 printing overhead by directly creating a schunk and converting it to cframe, - # instead of using blosc2.pack_tensor - schunk = blosc2.SChunk(data=self._mask) - dtype = self._mask.dtype.descr if self._mask.dtype.kind == "V" else self._mask.dtype.str - schunk.vlmeta["__pack_tensor__"] = ("numpy", self._mask.shape, dtype) - data_dict["_mask"] = schunk.to_cframe() - blosc2.set_nthreads(prev_nthreads) + data_dict["_mask"] = _pack_mask_array(self._mask) return data_dict def __setstate__(self, state: dict) -> None: - prev_nthreads = blosc2.set_nthreads(1) - state["_mask"] = blosc2.unpack_tensor(state["_mask"]) - blosc2.set_nthreads(prev_nthreads) + state["_mask"] = _unpack_mask_array(state["_mask"]) self.__dict__.update(state) @property @@ -497,6 +510,115 @@ def __eq__(self, other: Any) -> bool: return False return np.array_equal(self.bbox, other.bbox) and np.array_equal(self.mask, other.mask) + MASK_DATA_FIELD = "data" + + @staticmethod + def bbox_struct_fields(ndim: int) -> list[str]: + """ + Names of the bounding box fields of the mask struct attribute. + + Fields follow the ``bbox`` layout (start indices then end indices), + named after the (z), y, x axis convention, e.g. for 2D: + ``["min_y", "min_x", "max_y", "max_x"]``. + + Parameters + ---------- + ndim : int + The number of spatial dimensions (1 to 3). + + Returns + ------- + list[str] + The bounding box field names. + """ + if ndim < 1 or ndim > 3: + raise ValueError(f"Mask struct attributes are only supported for 1D to 3D masks, got ndim={ndim}") + axes = "zyx"[-ndim:] + return [f"min_{a}" for a in axes] + [f"max_{a}" for a in axes] + + @staticmethod + def struct_dtype(ndim: int) -> pl.Struct: + """ + Polars struct dtype used to store a `Mask` as a struct attribute. + + Bounding box coordinates are scalar integer fields so backends can + filter on them natively (e.g. `NodeAttr("mask").struct.field("min_y") > 5`), + while the binary mask is stored blosc2-compressed in the ``data`` field. + + Parameters + ---------- + ndim : int + The number of spatial dimensions (1 to 3). + + Returns + ------- + pl.Struct + The struct dtype, e.g. for 2D: + `pl.Struct({"min_y": Int64, "min_x": Int64, "max_y": Int64, "max_x": Int64, "data": Binary})`. + """ + fields = dict.fromkeys(Mask.bbox_struct_fields(ndim), pl.Int64) + fields[Mask.MASK_DATA_FIELD] = pl.Binary + return pl.Struct(fields) + + def to_struct(self) -> dict[str, Any]: + """ + Convert the mask to a dict matching [struct_dtype][tracksdata.nodes.Mask.struct_dtype]. + + Returns + ------- + dict[str, Any] + Scalar bounding box fields plus the blosc2-compressed mask under ``"data"``. + """ + fields = self.bbox_struct_fields(self._mask.ndim) + value: dict[str, Any] = {f: int(b) for f, b in zip(fields, self._bbox, strict=True)} + value[self.MASK_DATA_FIELD] = _pack_mask_array(self._mask) + return value + + @classmethod + def from_struct(cls, value: dict[str, Any]) -> "Mask": + """ + Reconstruct a mask from a struct attribute value. + + Parameters + ---------- + value : dict[str, Any] + A dict as produced by [to_struct][tracksdata.nodes.Mask.to_struct]. + + Returns + ------- + Mask + The reconstructed mask. + """ + mask = _unpack_mask_array(value[cls.MASK_DATA_FIELD]) + fields = cls.bbox_struct_fields(mask.ndim) + bbox = np.asarray([value[f] for f in fields], dtype=np.int64) + return cls(mask, bbox) + + +def as_mask(value: "Mask | dict[str, Any]") -> Mask: + """ + Coerce a mask attribute value to a `Mask` instance. + + Accepts both the struct representation (dicts, as returned by graph + backends for struct mask attributes) and `Mask` instances + (legacy `pl.Object` mask attributes). + + Parameters + ---------- + value : Mask | dict[str, Any] + The mask attribute value. + + Returns + ------- + Mask + The coerced mask. + """ + if isinstance(value, Mask): + return value + if isinstance(value, dict): + return Mask.from_struct(value) + raise TypeError(f"Cannot interpret {type(value)} as a Mask.") + class MaskDiskAttrs(GenericFuncNodeAttrs): """ @@ -540,7 +662,7 @@ def __init__( center=np.asarray(list(kwargs.values())), radius=radius, image_shape=image_shape, - ), + ).to_struct(), output_key=output_key, attr_keys=attr_keys, batch_size=0, @@ -551,4 +673,4 @@ def _init_node_attrs(self, graph: "BaseGraph") -> None: Validate that the output key exists in the graph. """ if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, pl.Object) + graph.add_node_attr_key(self.output_key, Mask.struct_dtype(len(self._image_shape))) diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index 11ad7261..c53e9ac0 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -139,6 +139,9 @@ def _init_node_attrs(self, graph: BaseGraph, node_attrs: dict[str, Any]) -> None elif np.isscalar(value): dtype = numpy_char_code_to_dtype(value.dtype) if hasattr(value, "dtype") else type(value) graph.add_node_attr_key(key, dtype) + elif isinstance(value, dict): + # struct-valued attributes, e.g. masks stored as `Mask.to_struct()` + graph.add_node_attr_key(key, pl.Series([value]).dtype) elif type(value).__module__ != "builtins": graph.add_node_attr_key(key, pl.Object) else: @@ -306,7 +309,7 @@ def _nodes_per_time( else: attrs[prop] = getattr(obj, prop) - attrs[DEFAULT_ATTR_KEYS.MASK] = Mask(obj.image, obj.bbox) + attrs[DEFAULT_ATTR_KEYS.MASK] = Mask(obj.image, obj.bbox).to_struct() attrs[DEFAULT_ATTR_KEYS.BBOX] = np.asarray(obj.bbox, dtype=int) attrs[DEFAULT_ATTR_KEYS.T] = t diff --git a/src/tracksdata/nodes/_test/test_mask.py b/src/tracksdata/nodes/_test/test_mask.py index 89f9637f..dc017e25 100644 --- a/src/tracksdata/nodes/_test/test_mask.py +++ b/src/tracksdata/nodes/_test/test_mask.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from tracksdata.nodes._mask import Mask, _nd_sphere +from tracksdata.nodes._mask import Mask, _nd_sphere, as_mask def test_mask_init() -> None: @@ -601,3 +601,87 @@ def test_mask_move() -> None: mask.move(offset=np.asarray([-3, 2]), image_shape=(7, 7)) np.testing.assert_array_equal(mask.bbox, [2, 4, 3, 5]) np.testing.assert_array_equal(mask.mask, point) + + +def test_mask_struct_dtype() -> None: + import polars as pl + + dtype_2d = Mask.struct_dtype(2) + assert dtype_2d == pl.Struct( + { + "min_y": pl.Int64, + "min_x": pl.Int64, + "max_y": pl.Int64, + "max_x": pl.Int64, + "data": pl.Binary, + } + ) + + dtype_3d = Mask.struct_dtype(3) + assert [f.name for f in dtype_3d.fields] == [ + "min_z", + "min_y", + "min_x", + "max_z", + "max_y", + "max_x", + "data", + ] + + with pytest.raises(ValueError): + Mask.struct_dtype(4) + + +@pytest.mark.parametrize("ndim", [2, 3]) +def test_mask_struct_roundtrip(ndim: int) -> None: + rng = np.random.default_rng(0) + shape = (3, 4, 5)[:ndim] + mask_data = rng.uniform(size=shape) > 0.5 + bbox = np.concatenate([np.arange(1, ndim + 1), np.arange(1, ndim + 1) + shape]) + + mask = Mask(mask_data, bbox=bbox) + value = mask.to_struct() + + assert isinstance(value, dict) + assert isinstance(value["data"], bytes) + assert value["min_y" if ndim == 2 else "min_z"] == 1 + + restored = Mask.from_struct(value) + assert restored == mask + + +def test_as_mask() -> None: + mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([0, 0, 2, 2])) + + assert as_mask(mask) is mask + assert as_mask(mask.to_struct()) == mask + + with pytest.raises(TypeError): + as_mask("not a mask") + + +def test_mask_struct_attr_in_graph(graph_backend) -> None: + """Masks stored as struct attributes round-trip and are filterable by bbox fields.""" + + from tracksdata.attrs import NodeAttr + from tracksdata.constants import DEFAULT_ATTR_KEYS + + graph = graph_backend + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, Mask.struct_dtype(2)) + + mask_a = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([0, 0, 2, 2])) + mask_b = Mask(np.ones((2, 3), dtype=bool), bbox=np.array([5, 6, 7, 9])) + + node_a = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask_a.to_struct()}) + node_b = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask_b.to_struct()}) + + df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MASK]) + assert df[DEFAULT_ATTR_KEYS.MASK].dtype == Mask.struct_dtype(2) + + restored = {n: as_mask(v) for n, v in zip(df[DEFAULT_ATTR_KEYS.NODE_ID], df[DEFAULT_ATTR_KEYS.MASK], strict=True)} + assert restored[node_a] == mask_a + assert restored[node_b] == mask_b + + # filtering on a bbox field of the mask struct + filtered = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.MASK).struct.field("min_y") > 2).node_ids() + assert filtered == [node_b] diff --git a/src/tracksdata/nodes/_test/test_regionprops.py b/src/tracksdata/nodes/_test/test_regionprops.py index 567c62e0..9e8a3d5b 100644 --- a/src/tracksdata/nodes/_test/test_regionprops.py +++ b/src/tracksdata/nodes/_test/test_regionprops.py @@ -4,7 +4,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import Mask, RegionPropsNodes +from tracksdata.nodes import Mask, RegionPropsNodes, as_mask from tracksdata.options import get_options, options_context @@ -278,12 +278,14 @@ def test_regionprops_mask_creation() -> None: assert "shape" in graph.metadata assert graph.metadata["shape"] == labels.shape - # Check that masks were created + # Check that masks were created as struct attributes nodes_df = graph.node_attrs() masks = nodes_df[DEFAULT_ATTR_KEYS.MASK] + assert masks.dtype == Mask.struct_dtype(labels.ndim - 1) - # All masks should be Mask objects + # All masks should convert back to Mask objects for mask in masks: + mask = as_mask(mask) assert isinstance(mask, Mask) assert mask._mask is not None assert mask._bbox is not None diff --git a/src/tracksdata/utils/_dataframe.py b/src/tracksdata/utils/_dataframe.py index a6de0f17..de4f5a18 100644 --- a/src/tracksdata/utils/_dataframe.py +++ b/src/tracksdata/utils/_dataframe.py @@ -1,3 +1,5 @@ +from collections.abc import Collection + import cloudpickle import polars as pl import polars.selectors as cs @@ -29,23 +31,41 @@ def unpack_array_attrs(df: pl.DataFrame) -> pl.DataFrame: return unpack_array_attrs(df) -def unpickle_bytes_columns(df: pl.DataFrame) -> pl.DataFrame: +def unpickle_columns(df: pl.DataFrame, columns: Collection[str]) -> pl.DataFrame: """ - Unpickle bytes columns from the database. + Unpickle pickled bytes columns read from the database. + + Only the columns in *columns* are unpickled. Raw-binary columns (e.g. the + blosc2-compressed ``data`` leaf of a Mask struct attribute) are stored + natively as ``pl.Binary`` and must be left untouched, so callers pass the + explicit set of genuinely-pickled physical columns rather than relying on + all binary columns being pickled. Parameters ---------- df : pl.DataFrame The DataFrame to unpickle the bytes columns from. + columns : Collection[str] + The physical column names that hold pickled values. Returns ------- pl.DataFrame - The DataFrame with the bytes columns unpickled. + The DataFrame with the pickled columns unpickled. """ - df = df.map_columns(cs.binary(), lambda x: x.map_elements(cloudpickle.loads, return_dtype=pl.Object)) - for col, dtype in zip(df.columns, df.dtypes, strict=True): - if isinstance(dtype, pl.Object): + # `columns` lists columns that are *defined* as pickled (SQL ``PickleType``), + # but the runtime dtype is inferred per query result since pickle columns are + # excluded from the polars schema override. A genuinely-pickled column can + # therefore come back as something other than ``pl.Binary`` (e.g. an all-NULL + # result is inferred as ``pl.Null``). Restrict to actual binary columns so + # ``cloudpickle.loads`` is only ever applied to real bytes. + targets = [col for col in columns if col in df.columns and df.schema[col] == pl.Binary] + if not targets: + return df + + df = df.map_columns(cs.by_name(targets), lambda x: x.map_elements(cloudpickle.loads, return_dtype=pl.Object)) + for col in targets: + if isinstance(df.schema[col], pl.Object): try: df = df.with_columns(pl.Series(df[col].to_list()).alias(col)) except Exception: diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 5c25b710..a245eee0 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -392,6 +392,9 @@ def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: # String types pl.String: sa.String, pl.Utf8: sa.String, + # Raw binary blobs are stored as-is (no pickle round-trip), e.g. the + # blosc2-compressed `data` leaf of a Mask struct attribute. + pl.Binary: sa.LargeBinary, } diff --git a/src/tracksdata/utils/_test/test_dataframe.py b/src/tracksdata/utils/_test/test_dataframe.py index 0ac92389..930d74cc 100644 --- a/src/tracksdata/utils/_test/test_dataframe.py +++ b/src/tracksdata/utils/_test/test_dataframe.py @@ -2,7 +2,7 @@ import numpy as np import polars as pl -from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns +from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_columns def test_unpack_array_attrs() -> None: @@ -34,7 +34,7 @@ def test_unpack_array_attrs() -> None: ) -def test_unpickle_bytes_columns_variable_size_arrays() -> None: +def test_unpickle_columns_variable_size_arrays() -> None: """Regression: unpickling a binary column with variable-size numpy arrays must not crash. This reproduces the production bug triggered by GEFF import: masks are stored as raw @@ -45,7 +45,48 @@ def test_unpickle_bytes_columns_variable_size_arrays() -> None: arrays = [np.ones((41, 41), dtype=bool), np.ones((4, 4), dtype=bool)] df = pl.DataFrame({"mask": pl.Series([cloudpickle.dumps(a) for a in arrays], dtype=pl.Binary)}) - result = unpickle_bytes_columns(df) # must not raise SchemaError + result = unpickle_columns(df, ["mask"]) # must not raise SchemaError for actual, expected in zip(result["mask"].to_list(), arrays, strict=False): np.testing.assert_array_equal(actual, expected) + + +def test_unpickle_columns_all_null_column() -> None: + """A pickle column that comes back entirely NULL is inferred as pl.Null, not pl.Binary. + + Such columns must be skipped so ``cloudpickle.loads`` is never applied to non-bytes + values, and the column is returned untouched. + """ + df = pl.DataFrame({"mask": [None, None, None]}) + assert df.schema["mask"] == pl.Null + + result = unpickle_columns(df, ["mask"]) # must not raise + + assert result.schema["mask"] == pl.Null + assert result["mask"].to_list() == [None, None, None] + + +def test_unpickle_columns_leaves_raw_binary_untouched() -> None: + """Only columns named in *columns* are unpickled; raw-binary columns are left as-is. + + Both a genuinely-pickled column and a raw-binary column (e.g. the blosc2-compressed + Mask ``data`` leaf) come back as ``pl.Binary``, so the function must rely on the + explicit *columns* set rather than the dtype to decide what to unpickle. + """ + pickled = [cloudpickle.dumps(np.array([1, 2, 3])), cloudpickle.dumps(np.array([4, 5]))] + raw = [b"\x00raw-bytes-not-pickled\x01", b"\x02another-raw-blob\x03"] + df = pl.DataFrame( + { + "pickled": pl.Series(pickled, dtype=pl.Binary), + "raw": pl.Series(raw, dtype=pl.Binary), + } + ) + + # "raw" is a binary column but is NOT listed as pickled, so it must be skipped. + result = unpickle_columns(df, ["pickled"]) + + for actual, expected in zip(result["pickled"].to_list(), [np.array([1, 2, 3]), np.array([4, 5])], strict=True): + np.testing.assert_array_equal(actual, expected) + + assert result.schema["raw"] == pl.Binary + assert result["raw"].to_list() == raw