From 29177f30247cb94daeb9f2bf51297f694a2818ff Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 13:16:47 +0900 Subject: [PATCH 01/30] added private metadata machinery --- src/tracksdata/graph/_base_graph.py | 46 +++++++++++++++++-- src/tracksdata/graph/_graph_view.py | 12 ++--- src/tracksdata/graph/_rustworkx_graph.py | 6 +-- src/tracksdata/graph/_sql_graph.py | 6 +-- .../graph/_test/test_graph_backends.py | 21 +++++++++ 5 files changed, 76 insertions(+), 15 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5b3708ad..4340ca1d 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -47,6 +47,8 @@ class BaseGraph(abc.ABC): Base class for a graph backend. """ + _PRIVATE_METADATA_PREFIX = "__private_" + node_added = Signal(int) node_removed = Signal(int) @@ -1187,6 +1189,9 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.update_metadata(**other.metadata()) + private_metadata = other._private_metadata() + if private_metadata: + graph._update_metadata(**private_metadata) current_node_attr_schemas = graph._node_attr_schemas() for k, v in other._node_attr_schemas().items(): @@ -1824,7 +1829,6 @@ def to_geff( zarr_format=zarr_format, ) - @abc.abstractmethod def metadata(self) -> dict[str, Any]: """ Return the metadata of the graph. @@ -1841,8 +1845,8 @@ def metadata(self) -> dict[str, Any]: print(metadata["shape"]) ``` """ + return {k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)} - @abc.abstractmethod def update_metadata(self, **kwargs) -> None: """ Set or update metadata for the graph. 
@@ -1859,8 +1863,9 @@ def update_metadata(self, **kwargs) -> None: graph.update_metadata(description="Tracking data from experiment 1") ``` """ + self._validate_public_metadata_keys(kwargs.keys()) + self._update_metadata(**kwargs) - @abc.abstractmethod def remove_metadata(self, key: str) -> None: """ Remove a metadata key from the graph. @@ -1876,6 +1881,41 @@ def remove_metadata(self, key: str) -> None: graph.remove_metadata("shape") ``` """ + self._validate_public_metadata_key(key) + self._remove_metadata(key) + + @classmethod + def _is_private_metadata_key(cls, key: str) -> bool: + return key.startswith(cls._PRIVATE_METADATA_PREFIX) + + def _validate_public_metadata_key(self, key: str) -> None: + if self._is_private_metadata_key(key): + raise ValueError(f"Metadata key '{key}' is reserved for internal use.") + + def _validate_public_metadata_keys(self, keys: Sequence[str]) -> None: + for key in keys: + self._validate_public_metadata_key(key) + + def _private_metadata(self) -> dict[str, Any]: + return {k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)} + + @abc.abstractmethod + def _metadata(self) -> dict[str, Any]: + """ + Return the full metadata including private keys. + """ + + @abc.abstractmethod + def _update_metadata(self, **kwargs) -> None: + """ + Backend-specific metadata update implementation without public key validation. + """ + + @abc.abstractmethod + def _remove_metadata(self, key: str) -> None: + """ + Backend-specific metadata removal implementation without public key validation. 
+ """ def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph": """ diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index b9f82ead..c689931d 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -847,11 +847,11 @@ def copy(self, **kwargs) -> "GraphView": "Use `detach` to create a new reference-less graph with the same nodes and edges." ) - def metadata(self) -> dict[str, Any]: - return self._root.metadata() + def _metadata(self) -> dict[str, Any]: + return self._root._metadata() - def update_metadata(self, **kwargs) -> None: - self._root.update_metadata(**kwargs) + def _update_metadata(self, **kwargs) -> None: + self._root._update_metadata(**kwargs) - def remove_metadata(self, key: str) -> None: - self._root.remove_metadata(key) + def _remove_metadata(self, key: str) -> None: + self._root._remove_metadata(key) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ef4a3f4f..229eacc2 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1499,13 +1499,13 @@ def edge_id(self, source_id: int, target_id: int) -> int: """ return self.rx_graph.get_edge_data(source_id, target_id)[DEFAULT_ATTR_KEYS.EDGE_ID] - def metadata(self) -> dict[str, Any]: + def _metadata(self) -> dict[str, Any]: return self._graph.attrs - def update_metadata(self, **kwargs) -> None: + def _update_metadata(self, **kwargs) -> None: self._graph.attrs.update(kwargs) - def remove_metadata(self, key: str) -> None: + def _remove_metadata(self, key: str) -> None: self._graph.attrs.pop(key, None) def edge_list(self) -> list[list[int, int]]: diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 985cbdc9..c8ea38ed 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1992,19 +1992,19 @@ def remove_edge( raise 
ValueError(f"Edge {edge_id} does not exist in the graph.") session.commit() - def metadata(self) -> dict[str, Any]: + def _metadata(self) -> dict[str, Any]: with Session(self._engine) as session: result = session.query(self.Metadata).all() return {row.key: row.value for row in result} - def update_metadata(self, **kwargs) -> None: + def _update_metadata(self, **kwargs) -> None: with Session(self._engine) as session: for key, value in kwargs.items(): metadata_entry = self.Metadata(key=key, value=value) session.merge(metadata_entry) session.commit() - def remove_metadata(self, key: str) -> None: + def _remove_metadata(self, key: str) -> None: with Session(self._engine) as session: session.query(self.Metadata).filter(self.Metadata.key == key).delete() session.commit() diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index d6084cd8..7619188e 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -2511,6 +2511,27 @@ def test_metadata_multiple_dtypes(graph_backend: BaseGraph) -> None: assert "mixed_list" not in retrieved +def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) -> None: + private_key = "__private_dtype_map" + + graph_backend._update_metadata(**{private_key: {"x": "float64"}}) + graph_backend.update_metadata(shape=[1, 2, 3]) + + public_metadata = graph_backend.metadata() + assert private_key not in public_metadata + assert public_metadata["shape"] == [1, 2, 3] + + with pytest.raises(ValueError, match="reserved for internal use"): + graph_backend.update_metadata(**{private_key: {"x": "int64"}}) + + with pytest.raises(ValueError, match="reserved for internal use"): + graph_backend.remove_metadata(private_key) + + # Internal APIs can still remove private keys. 
+ graph_backend._remove_metadata(private_key) + assert private_key not in graph_backend._metadata() + + def test_pickle_roundtrip(graph_backend: BaseGraph) -> None: if isinstance(graph_backend, SQLGraph): pytest.skip("SQLGraph does not support pickle roundtrip") From d8292f1c9b01ac75c94316d8a990937be3bd74e2 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 14:01:04 +0900 Subject: [PATCH 02/30] before adding private --- src/tracksdata/array/_graph_array.py | 2 +- .../functional/_test/test_napari.py | 2 +- src/tracksdata/graph/__init__.py | 4 +- src/tracksdata/graph/_base_graph.py | 118 +++++++++++------- src/tracksdata/graph/_rustworkx_graph.py | 2 +- .../graph/_test/test_graph_backends.py | 44 +++---- src/tracksdata/io/_test/test_ctc_io.py | 2 +- src/tracksdata/nodes/_regionprops.py | 4 +- .../nodes/_test/test_regionprops.py | 36 +++--- 9 files changed, 123 insertions(+), 91 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 018e7ae6..80418986 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -23,7 +23,7 @@ def _validate_shape( """Helper function to validate the shape argument.""" if shape is None: try: - shape = graph.metadata()["shape"] + shape = graph.metadata["shape"] except KeyError as e: raise KeyError( f"`shape` is required to `{func_name}`. 
" diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index 9b4a81dc..712cf53d 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -31,7 +31,7 @@ def test_napari_conversion(metadata_shape: bool) -> None: shape = (2, 10, 22, 32) if metadata_shape: - graph.update_metadata(shape=shape) + graph.metadata.update(shape=shape) arg_shape = None else: arg_shape = shape diff --git a/src/tracksdata/graph/__init__.py b/src/tracksdata/graph/__init__.py index fcf207e2..3906949b 100644 --- a/src/tracksdata/graph/__init__.py +++ b/src/tracksdata/graph/__init__.py @@ -1,10 +1,10 @@ """Graph backends for representing tracking data as directed graphs in memory or on disk.""" -from tracksdata.graph._base_graph import BaseGraph +from tracksdata.graph._base_graph import BaseGraph, MetadataView from tracksdata.graph._graph_view import GraphView from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph from tracksdata.graph._sql_graph import SQLGraph InMemoryGraph = RustWorkXGraph -__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "RustWorkXGraph", "SQLGraph"] +__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "MetadataView", "RustWorkXGraph", "SQLGraph"] diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 4340ca1d..bfc8239b 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -42,6 +42,61 @@ T = TypeVar("T", bound="BaseGraph") +class MetadataView(dict[str, Any]): + """Dictionary-like metadata view that syncs mutations back to the graph.""" + + _MISSING = object() + + def __init__(self, graph: "BaseGraph", data: dict[str, Any]) -> None: + super().__init__(data) + self._graph = graph + + def __setitem__(self, key: str, value: Any) -> None: + self._graph._set_public_metadata(**{key: value}) + super().__setitem__(key, value) + + 
def __delitem__(self, key: str) -> None: + self._graph._remove_public_metadata(key) + super().__delitem__(key) + + def pop(self, key: str, default: Any = _MISSING) -> Any: + self._graph._validate_public_metadata_key(key) + + if key not in self: + if default is self._MISSING: + raise KeyError(key) + return default + + value = super().__getitem__(key) + self._graph._remove_metadata(key) + super().pop(key, None) + return value + + def popitem(self) -> tuple[str, Any]: + key, value = super().popitem() + self._graph._remove_metadata(key) + return key, value + + def clear(self) -> None: + keys = list(self.keys()) + for key in keys: + self._graph._remove_metadata(key) + super().clear() + + def setdefault(self, key: str, default: Any = None) -> Any: + if key in self: + return super().__getitem__(key) + self._graph._set_public_metadata(**{key: default}) + super().__setitem__(key, default) + return default + + def update(self, *args, **kwargs) -> None: + updates = dict(*args, **kwargs) + if updates: + self._graph._set_public_metadata(**updates) + super().update(updates) + + class BaseGraph(abc.ABC): """ Base class for a graph backend. @@ -1188,7 +1243,7 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID) graph = cls(**kwargs) - graph.update_metadata(**other.metadata()) + graph.metadata.update(other.metadata) private_metadata = other._private_metadata() if private_metadata: graph._update_metadata(**private_metadata) @@ -1791,7 +1846,7 @@ def to_geff( for k, v in edge_attrs.to_dict().items() } - td_metadata = self.metadata().copy() + td_metadata = self.metadata.copy() td_metadata.pop("geff", None) # avoid geff being written multiple times geff_metadata = geff.GeffMetadata( @@ -1829,66 +1884,35 @@ def to_geff( zarr_format=zarr_format, ) - def metadata(self) -> dict[str, Any]: + @property + def metadata(self) -> MetadataView: """ Return the metadata of the graph. 
Returns ------- - dict[str, Any] + MetadataView The metadata of the graph as a dictionary. Examples -------- ```python - metadata = graph.metadata() + metadata = graph.metadata print(metadata["shape"]) ``` """ - return {k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)} - - def update_metadata(self, **kwargs) -> None: - """ - Set or update metadata for the graph. - - Parameters - ---------- - **kwargs : Any - The metadata items to set by key. Values will be stored as JSON. - - Examples - -------- - ```python - graph.update_metadata(shape=[1, 25, 25], path="path/to/image.ome.zarr") - graph.update_metadata(description="Tracking data from experiment 1") - ``` - """ - self._validate_public_metadata_keys(kwargs.keys()) - self._update_metadata(**kwargs) - - def remove_metadata(self, key: str) -> None: - """ - Remove a metadata key from the graph. - - Parameters - ---------- - key : str - The key of the metadata to remove. - - Examples - -------- - ```python - graph.remove_metadata("shape") - ``` - """ - self._validate_public_metadata_key(key) - self._remove_metadata(key) + return MetadataView( + graph=self, + data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)}, + ) @classmethod def _is_private_metadata_key(cls, key: str) -> bool: return key.startswith(cls._PRIVATE_METADATA_PREFIX) def _validate_public_metadata_key(self, key: str) -> None: + if not isinstance(key, str): + raise TypeError(f"Metadata key must be a string. 
Got {type(key)}.") if self._is_private_metadata_key(key): raise ValueError(f"Metadata key '{key}' is reserved for internal use.") @@ -1896,6 +1920,14 @@ def _validate_public_metadata_keys(self, keys: Sequence[str]) -> None: for key in keys: self._validate_public_metadata_key(key) + def _set_public_metadata(self, **kwargs) -> None: + self._validate_public_metadata_keys(kwargs.keys()) + self._update_metadata(**kwargs) + + def _remove_public_metadata(self, key: str) -> None: + self._validate_public_metadata_key(key) + self._remove_metadata(key) + def _private_metadata(self) -> dict[str, Any]: return {k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)} diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 229eacc2..05cf1c17 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -371,7 +371,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: elif not isinstance(self._graph.attrs, dict): LOG.warning( - "previous attribute %s will be added to key 'old_attrs' of `graph.metadata()`", + "previous attribute %s will be added to key 'old_attrs' of `graph.metadata`", self._graph.attrs, ) self._graph.attrs = { diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 7619188e..73a1161c 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1359,7 +1359,7 @@ def test_from_other_with_edges( ) -> None: """Ensure from_other preserves structure across backend conversions.""" # Create source graph with nodes, edges, and attributes - graph_backend.update_metadata(special_key="special_value") + graph_backend.metadata.update(special_key="special_value") graph_backend.add_node_attr_key("x", dtype=pl.Float64) graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=-1) @@ -1386,7 +1386,7 @@ def 
test_from_other_with_edges( assert set(new_graph.node_attr_keys()) == set(graph_backend.node_attr_keys()) assert set(new_graph.edge_attr_keys()) == set(graph_backend.edge_attr_keys()) - assert new_graph.metadata() == graph_backend.metadata() + assert new_graph.metadata == graph_backend.metadata assert new_graph._node_attr_schemas() == graph_backend._node_attr_schemas() assert new_graph._edge_attr_schemas() == graph_backend._edge_attr_schemas() @@ -2322,7 +2322,7 @@ def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: graph_backend.add_edge_attr_key("weight", pl.Float16) - graph_backend.update_metadata( + graph_backend.metadata.update( shape=[1, 25, 25], path="path/to/image.ome.zarr", ) @@ -2383,11 +2383,11 @@ def test_geff_roundtrip(graph_backend: BaseGraph) -> None: geff_graph, _ = IndexedRXGraph.from_geff(output_store) - assert "geff" in geff_graph.metadata() + assert "geff" in geff_graph.metadata # geff metadata was not stored in original graph - geff_graph.metadata().pop("geff") - assert geff_graph.metadata() == graph_backend.metadata() + geff_graph.metadata.pop("geff") + assert geff_graph.metadata == graph_backend.metadata assert geff_graph.num_nodes() == 3 assert geff_graph.num_edges() == 2 @@ -2442,11 +2442,11 @@ def test_geff_with_keymapping(graph_backend: BaseGraph) -> None: edge_attr_key_map={"weight": "weight_new"}, ) - assert "geff" in geff_graph.metadata() + assert "geff" in geff_graph.metadata # geff metadata was not stored in original graph - geff_graph.metadata().pop("geff") - assert geff_graph.metadata() == graph_backend.metadata() + geff_graph.metadata.pop("geff") + assert geff_graph.metadata == graph_backend.metadata assert geff_graph.num_nodes() == 3 assert geff_graph.num_edges() == 2 @@ -2483,30 +2483,30 @@ def test_metadata_multiple_dtypes(graph_backend: BaseGraph) -> None: } # Update metadata with all test values - graph_backend.update_metadata(**test_metadata) + graph_backend.metadata.update(**test_metadata) # Retrieve and verify - 
retrieved = graph_backend.metadata() + retrieved = graph_backend.metadata for key, expected_value in test_metadata.items(): assert key in retrieved, f"Key '{key}' not found in metadata" assert retrieved[key] == expected_value, f"Value mismatch for '{key}': {retrieved[key]} != {expected_value}" # Test updating existing keys - graph_backend.update_metadata(string="updated_value", new_key="new_value") - retrieved = graph_backend.metadata() + graph_backend.metadata.update(string="updated_value", new_key="new_value") + retrieved = graph_backend.metadata assert retrieved["string"] == "updated_value" assert retrieved["new_key"] == "new_value" assert retrieved["integer"] == 42 # Other values unchanged # Testing removing metadata - graph_backend.remove_metadata("string") - retrieved = graph_backend.metadata() + graph_backend.metadata.pop("string", None) + retrieved = graph_backend.metadata assert "string" not in retrieved - graph_backend.remove_metadata("mixed_list") - retrieved = graph_backend.metadata() + graph_backend.metadata.pop("mixed_list", None) + retrieved = graph_backend.metadata assert "string" not in retrieved assert "mixed_list" not in retrieved @@ -2515,17 +2515,17 @@ def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) - private_key = "__private_dtype_map" graph_backend._update_metadata(**{private_key: {"x": "float64"}}) - graph_backend.update_metadata(shape=[1, 2, 3]) + graph_backend.metadata.update(shape=[1, 2, 3]) - public_metadata = graph_backend.metadata() + public_metadata = graph_backend.metadata assert private_key not in public_metadata assert public_metadata["shape"] == [1, 2, 3] with pytest.raises(ValueError, match="reserved for internal use"): - graph_backend.update_metadata(**{private_key: {"x": "int64"}}) + graph_backend.metadata.update(**{private_key: {"x": "int64"}}) with pytest.raises(ValueError, match="reserved for internal use"): - graph_backend.remove_metadata(private_key) + graph_backend.metadata.pop(private_key, 
None) # Internal APIs can still remove private keys. graph_backend._remove_metadata(private_key) @@ -2606,7 +2606,7 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None: graph_backend.add_node_attr_key("y", pl.Float64) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) - graph_backend.update_metadata(shape=[3, 25, 25]) + graph_backend.metadata.update(shape=[3, 25, 25]) # Create masks for first graph mask1_data = np.array([[True, True], [True, True]], dtype=bool) diff --git a/src/tracksdata/io/_test/test_ctc_io.py b/src/tracksdata/io/_test/test_ctc_io.py index 7c5fb925..01025213 100644 --- a/src/tracksdata/io/_test/test_ctc_io.py +++ b/src/tracksdata/io/_test/test_ctc_io.py @@ -68,7 +68,7 @@ def test_export_from_ctc_roundtrip(tmp_path: Path, metadata_shape: bool) -> None in_graph.add_edge(node_1, node_3, attrs={DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) if metadata_shape: - in_graph.update_metadata(shape=(2, 4, 4)) + in_graph.metadata.update(shape=(2, 4, 4)) shape = None else: shape = (2, 4, 4) diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index c78feb32..5be49713 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -230,8 +230,8 @@ def add_nodes( axis_names = self._axis_names(labels) self._init_node_attrs(graph, axis_names, ndims=labels.ndim) - if "shape" not in graph.metadata(): - graph.update_metadata(shape=labels.shape) + if "shape" not in graph.metadata: + graph.metadata.update(shape=labels.shape) if t is None: time_points = range(labels.shape[0]) diff --git a/src/tracksdata/nodes/_test/test_regionprops.py b/src/tracksdata/nodes/_test/test_regionprops.py index 350d231b..567c62e0 100644 --- a/src/tracksdata/nodes/_test/test_regionprops.py +++ b/src/tracksdata/nodes/_test/test_regionprops.py @@ -79,8 +79,8 @@ def test_regionprops_add_nodes_2d() -> None: operator = 
RegionPropsNodes(extra_properties=extra_properties) operator.add_nodes(graph, labels=labels) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added assert graph.num_nodes() == 2 # Two regions (labels 1 and 2) @@ -115,8 +115,8 @@ def test_regionprops_add_nodes_3d() -> None: operator = RegionPropsNodes(extra_properties=extra_properties) operator.add_nodes(graph, labels=labels) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added assert graph.num_nodes() == 2 # Two regions @@ -150,8 +150,8 @@ def test_regionprops_add_nodes_with_intensity() -> None: operator.add_nodes(graph, labels=labels, intensity_image=intensity) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added with intensity attributes nodes_df = graph.node_attrs() @@ -181,8 +181,8 @@ def test_regionprops_add_nodes_timelapse(n_workers: int) -> None: with options_context(n_workers=n_workers): operator.add_nodes(graph, labels=labels) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added for both time points nodes_df = graph.node_attrs() @@ -209,8 +209,8 @@ def test_regionprops_add_nodes_timelapse_with_intensity() -> None: operator.add_nodes(graph, labels=labels, intensity_image=intensity) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added with intensity attributes nodes_df = 
graph.node_attrs() @@ -237,8 +237,8 @@ def double_area(region: RegionProperties) -> float: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that custom property was calculated nodes_df = graph.node_attrs() @@ -275,8 +275,8 @@ def test_regionprops_mask_creation() -> None: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that masks were created nodes_df = graph.node_attrs() @@ -300,8 +300,8 @@ def test_regionprops_spacing() -> None: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # Check that nodes were added (spacing affects internal calculations) nodes_df = graph.node_attrs() @@ -323,8 +323,8 @@ def test_regionprops_empty_labels() -> None: operator.add_nodes(graph, labels=labels, t=0) - assert "shape" in graph.metadata() - assert graph.metadata()["shape"] == labels.shape + assert "shape" in graph.metadata + assert graph.metadata["shape"] == labels.shape # No nodes should be added assert graph.num_nodes() == 0 From cff58981bd54328d07a6208e48c2410541504c31 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 14:11:08 +0900 Subject: [PATCH 03/30] added private metadata view --- src/tracksdata/graph/_base_graph.py | 63 ++++++++++++------- .../graph/_test/test_graph_backends.py | 9 ++- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index bfc8239b..28e8baa2 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ 
-47,20 +47,27 @@ class MetadataView(dict[str, Any]): _MISSING = object() - def __init__(self, graph: "BaseGraph", data: dict[str, Any]) -> None: + def __init__( + self, + graph: "BaseGraph", + data: dict[str, Any], + *, + is_public: bool = True, + ) -> None: super().__init__(data) self._graph = graph + self._is_public = is_public def __setitem__(self, key: str, value: Any) -> None: - self._graph._set_public_metadata(**{key: value}) + self._graph._set_public_metadata(is_public=self._is_public, **{key: value}) super().__setitem__(key, value) def __delitem__(self, key: str) -> None: - self._graph._remove_public_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) super().__delitem__(key) def pop(self, key: str, default: Any = _MISSING) -> Any: - self._graph._validate_public_metadata_key(key) + self._graph._validate_metadata_key(key, is_public=self._is_public) if key not in self: if default is self._MISSING: @@ -68,32 +75,32 @@ def pop(self, key: str, default: Any = _MISSING) -> Any: return default value = super().__getitem__(key) - self._graph._remove_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) super().pop(key, None) return value def popitem(self) -> tuple[str, Any]: key, value = super().popitem() - self._graph._remove_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) return key, value def clear(self) -> None: keys = list(self.keys()) for key in keys: - self._graph._remove_metadata(key) + self._graph._remove_public_metadata(key, is_public=self._is_public) super().clear() def setdefault(self, key: str, default: Any = None) -> Any: if key in self: return super().__getitem__(key) - self._graph._set_public_metadata(**{key: default}) + self._graph._set_public_metadata(is_public=self._is_public, **{key: default}) super().__setitem__(key, default) return default def update(self, *args, **kwargs) -> None: updates = dict(*args, **kwargs) if updates: - 
self._graph._set_public_metadata(**updates) + self._graph._set_public_metadata(is_public=self._is_public, **updates) super().update(updates) @@ -1244,9 +1251,7 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.metadata.update(other.metadata) - private_metadata = other._private_metadata() - if private_metadata: - graph._update_metadata(**private_metadata) + graph._private_metadata.update(other._private_metadata) current_node_attr_schemas = graph._node_attr_schemas() for k, v in other._node_attr_schemas().items(): @@ -1904,33 +1909,45 @@ def metadata(self) -> MetadataView: return MetadataView( graph=self, data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)}, + is_public=True, + ) + + @property + def _private_metadata(self) -> MetadataView: + return MetadataView( + graph=self, + data={k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)}, + is_public=False, ) @classmethod def _is_private_metadata_key(cls, key: str) -> bool: return key.startswith(cls._PRIVATE_METADATA_PREFIX) - def _validate_public_metadata_key(self, key: str) -> None: + def _validate_metadata_key(self, key: str, *, is_public: bool) -> None: if not isinstance(key, str): raise TypeError(f"Metadata key must be a string. Got {type(key)}.") - if self._is_private_metadata_key(key): + is_private_key = self._is_private_metadata_key(key) + if is_public and is_private_key: raise ValueError(f"Metadata key '{key}' is reserved for internal use.") + if not is_public and not is_private_key: + raise ValueError( + f"Metadata key '{key}' is not private. Private metadata keys must start with " + f"'{self._PRIVATE_METADATA_PREFIX}'." 
+ ) - def _validate_public_metadata_keys(self, keys: Sequence[str]) -> None: + def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> None: for key in keys: - self._validate_public_metadata_key(key) + self._validate_metadata_key(key, is_public=is_public) - def _set_public_metadata(self, **kwargs) -> None: - self._validate_public_metadata_keys(kwargs.keys()) + def _set_public_metadata(self, is_public: bool = True, **kwargs) -> None: + self._validate_metadata_keys(kwargs.keys(), is_public=is_public) self._update_metadata(**kwargs) - def _remove_public_metadata(self, key: str) -> None: - self._validate_public_metadata_key(key) + def _remove_public_metadata(self, key: str, *, is_public: bool = True) -> None: + self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) - def _private_metadata(self) -> dict[str, Any]: - return {k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)} - @abc.abstractmethod def _metadata(self) -> dict[str, Any]: """ diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 73a1161c..e9088c75 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -2514,7 +2514,7 @@ def test_metadata_multiple_dtypes(graph_backend: BaseGraph) -> None: def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) -> None: private_key = "__private_dtype_map" - graph_backend._update_metadata(**{private_key: {"x": "float64"}}) + graph_backend._private_metadata.update(**{private_key: {"x": "float64"}}) graph_backend.metadata.update(shape=[1, 2, 3]) public_metadata = graph_backend.metadata @@ -2527,8 +2527,11 @@ def test_private_metadata_is_hidden_from_public_apis(graph_backend: BaseGraph) - with pytest.raises(ValueError, match="reserved for internal use"): graph_backend.metadata.pop(private_key, None) - # Internal APIs can still remove private keys. 
- graph_backend._remove_metadata(private_key) + with pytest.raises(ValueError, match="is not private"): + graph_backend._private_metadata.update(shape=[1, 2, 3]) + + # Private metadata view can remove private keys. + graph_backend._private_metadata.pop(private_key, None) assert private_key not in graph_backend._metadata() From 68b01d40c6368a5120474dc38e88276f6eb121da Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 14:14:06 +0900 Subject: [PATCH 04/30] renamed func --- src/tracksdata/graph/_base_graph.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 28e8baa2..03dc3a01 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -59,11 +59,11 @@ def __init__( self._is_public = is_public def __setitem__(self, key: str, value: Any) -> None: - self._graph._set_public_metadata(is_public=self._is_public, **{key: value}) + self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: value}) super().__setitem__(key, value) def __delitem__(self, key: str) -> None: - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) super().__delitem__(key) def pop(self, key: str, default: Any = _MISSING) -> Any: @@ -75,32 +75,32 @@ def pop(self, key: str, default: Any = _MISSING) -> Any: return default value = super().__getitem__(key) - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) super().pop(key, None) return value def popitem(self) -> tuple[str, Any]: key, value = super().popitem() - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) return key, value def clear(self) -> None: keys = list(self.keys()) for key in 
keys: - self._graph._remove_public_metadata(key, is_public=self._is_public) + self._graph._remove_metadata_with_validation(key, is_public=self._is_public) super().clear() def setdefault(self, key: str, default: Any = None) -> Any: if key in self: return super().__getitem__(key) - self._graph._set_public_metadata(is_public=self._is_public, **{key: default}) + self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: default}) super().__setitem__(key, default) return default def update(self, *args, **kwargs) -> None: updates = dict(*args, **kwargs) if updates: - self._graph._set_public_metadata(is_public=self._is_public, **updates) + self._graph._set_metadata_with_validation(is_public=self._is_public, **updates) super().update(updates) @@ -1940,11 +1940,11 @@ def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> No for key in keys: self._validate_metadata_key(key, is_public=is_public) - def _set_public_metadata(self, is_public: bool = True, **kwargs) -> None: + def _set_metadata_with_validation(self, is_public: bool = True, **kwargs) -> None: self._validate_metadata_keys(kwargs.keys(), is_public=is_public) self._update_metadata(**kwargs) - def _remove_public_metadata(self, key: str, *, is_public: bool = True) -> None: + def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) -> None: self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) From 1ae242670a8051827723c926883359657980238a Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 15:34:33 +0900 Subject: [PATCH 05/30] implementation of saving and loading dtypes as metadata --- src/tracksdata/graph/_base_graph.py | 72 ++++++++++++++++- src/tracksdata/graph/_rustworkx_graph.py | 45 ++++++----- src/tracksdata/graph/_sql_graph.py | 71 ++++++++++++----- .../graph/_test/test_graph_backends.py | 77 +++++++++++++++++++ src/tracksdata/utils/_dtypes.py | 61 +++++++++++++++ .../utils/_test/test_dtype_serialization.py 
| 45 +++++++++++ 6 files changed, 332 insertions(+), 39 deletions(-) create mode 100644 src/tracksdata/utils/_test/test_dtype_serialization.py diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 03dc3a01..5c2c8f85 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload +import warnings import geff import numpy as np @@ -21,7 +22,9 @@ from tracksdata.utils._dtypes import ( AttrSchema, column_to_numpy, + deserialize_polars_dtype, polars_dtype_to_numpy_dtype, + serialize_polars_dtype, ) from tracksdata.utils._logging import LOG from tracksdata.utils._multiprocessing import multiprocessing_apply @@ -110,6 +113,7 @@ class BaseGraph(abc.ABC): """ _PRIVATE_METADATA_PREFIX = "__private_" + _PRIVATE_DTYPE_MAP_KEY = "__private_dtype_map" node_added = Signal(int) node_removed = Signal(int) @@ -1281,7 +1285,6 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: current_edge_attr_schemas = graph._edge_attr_schemas() for k, v in other._edge_attr_schemas().items(): if k not in current_edge_attr_schemas: - print(f"Adding edge attribute key: {k} with dtype: {v.dtype} and default value: {v.default_value}") graph.add_edge_attr_key(k, v.dtype, v.default_value) edge_attrs = edge_attrs.with_columns( @@ -1948,6 +1951,73 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) self._validate_metadata_key(key, is_public=is_public) self._remove_metadata(key) + def _get_private_dtype_map(self) -> dict[str, dict[str, str]]: + dtype_map = self._private_metadata.get(self._PRIVATE_DTYPE_MAP_KEY, {}) + if not isinstance(dtype_map, dict): + return {"node": {}, "edge": {}} + + node_dtype_map = dtype_map.get("node", {}) + edge_dtype_map = dtype_map.get("edge", {}) + if not isinstance(node_dtype_map, dict): + node_dtype_map = {} + if not 
isinstance(edge_dtype_map, dict): + edge_dtype_map = {} + + return {"node": dict(node_dtype_map), "edge": dict(edge_dtype_map)} + + def _set_private_dtype_map(self, dtype_map: dict[str, dict[str, str]]) -> None: + self._private_metadata.update( + **{ + self._PRIVATE_DTYPE_MAP_KEY: { + "node": dict(dtype_map.get("node", {})), + "edge": dict(dtype_map.get("edge", {})), + } + } + ) + + def _set_attr_dtype_metadata(self, *, key: str, dtype: pl.DataType, is_node: bool) -> None: + dtype_map = self._get_private_dtype_map() + map_key = "node" if is_node else "edge" + dtype_map[map_key][key] = serialize_polars_dtype(dtype) + self._set_private_dtype_map(dtype_map) + + def _remove_attr_dtype_metadata(self, *, key: str, is_node: bool) -> None: + dtype_map = self._get_private_dtype_map() + map_key = "node" if is_node else "edge" + dtype_map[map_key].pop(key, None) + self._set_private_dtype_map(dtype_map) + + def _attr_dtype_from_metadata(self, *, key: str, is_node: bool) -> pl.DataType | None: + dtype_map = self._get_private_dtype_map() + map_key = "node" if is_node else "edge" + encoded_dtype = dtype_map[map_key].get(key) + if not isinstance(encoded_dtype, str): + return None + + try: + return deserialize_polars_dtype(encoded_dtype) + except Exception: + warnings.warn( + f"Initializing schemas from existing database tables for the key {key}. " + "This is a fallback mechanism when loading existing graphs, and may not perfectly restore the original schemas. " + "This method is deprecated and will be removed in the major release. 
", + UserWarning, + ) + return None + + def _sync_attr_dtype_metadata(self) -> None: + dtype_map = { + "node": { + key: serialize_polars_dtype(schema.dtype) + for key, schema in self._node_attr_schemas().items() + }, + "edge": { + key: serialize_polars_dtype(schema.dtype) + for key, schema in self._edge_attr_schemas().items() + }, + } + self._set_private_dtype_map(dtype_map) + @abc.abstractmethod def _metadata(self) -> dict[str, Any]: """ diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 05cf1c17..a15dd1f5 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -400,11 +400,13 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: for key, value in first_node_attrs.items(): if key == DEFAULT_ATTR_KEYS.NODE_ID: continue - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + dtype = self._attr_dtype_from_metadata(key=key, is_node=True) + if dtype is None: + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) # Process edges: set edge IDs and infer schemas @@ -422,13 +424,17 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: # TODO: check if EDGE_SOURCE and EDGE_TARGET should be also ignored or in the schema if key == DEFAULT_ATTR_KEYS.EDGE_ID: continue - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + dtype = self._attr_dtype_from_metadata(key=key, is_node=False) + if dtype is None: + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + 
dtype = pl.Object self.__edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) + self._sync_attr_dtype_metadata() + def _node_attr_schemas(self) -> dict[str, AttrSchema]: return self.__node_attr_schemas @@ -986,6 +992,7 @@ def add_node_attr_key( # Store schema self.__node_attr_schemas[schema.key] = schema + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) def remove_node_attr_key(self, key: str) -> None: """ @@ -998,6 +1005,7 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") del self.__node_attr_schemas[key] + self._remove_attr_dtype_metadata(key=key, is_node=True) for node_attr in self.rx_graph.nodes(): node_attr.pop(key, None) @@ -1026,6 +1034,7 @@ def add_edge_attr_key( # Store schema self.__edge_attr_schemas[schema.key] = schema + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) def remove_edge_attr_key(self, key: str) -> None: """ @@ -1035,6 +1044,7 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") del self.__edge_attr_schemas[key] + self._remove_attr_dtype_metadata(key=key, is_node=False) for edge_attr in self.rx_graph.edges(): edge_attr.pop(key, None) @@ -1153,16 +1163,11 @@ def edge_attrs( edge_map = rx_graph.edge_index_map() if len(edge_map) == 0: - return pl.DataFrame( - { - key: [] - for key in [ - *attr_keys, - DEFAULT_ATTR_KEYS.EDGE_SOURCE, - DEFAULT_ATTR_KEYS.EDGE_TARGET, - ] - } - ) + empty_columns = {} + for key in [*attr_keys, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + schema = self._edge_attr_schemas()[key] + empty_columns[key] = pl.Series(name=key, values=[], dtype=schema.dtype) + return pl.DataFrame(empty_columns) source, target, data = zip(*edge_map.values(), strict=False) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index c8ea38ed..b36e3ab8 100644 --- 
a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -479,6 +479,7 @@ def __init__( # Initialize schemas from existing table columns self._init_schemas_from_tables() + self._sync_attr_dtype_metadata() self._max_id_per_time = {} self._update_max_id_per_time() @@ -556,12 +557,19 @@ def _init_schemas_from_tables(self) -> None: Initialize AttrSchema objects from existing database table columns. This is used when loading an existing graph from the database. """ + + node_column_names = list(self.Node.__table__.columns.keys()) + preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] + ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] + ordered_node_columns.extend(name for name in node_column_names if name not in preferred_node_order) + # Initialize node schemas from Node table columns - for column_name in self.Node.__table__.columns.keys(): + for column_name in ordered_node_columns: if column_name not in self.__node_attr_schemas: - column = self.Node.__table__.columns[column_name] - # Infer polars dtype from SQLAlchemy type - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) + pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=True) + if pl_dtype is None: + column = self.Node.__table__.columns[column_name] + pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value self.__node_attr_schemas[column_name] = AttrSchema( key=column_name, @@ -572,9 +580,10 @@ def _init_schemas_from_tables(self) -> None: for column_name in self.Edge.__table__.columns.keys(): # Skip internal edge columns if column_name not in self.__edge_attr_schemas: - column = self.Edge.__table__.columns[column_name] - # Infer polars dtype from SQLAlchemy type - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) + pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=False) + if pl_dtype is None: + column = 
self.Edge.__table__.columns[column_name] + pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value self.__edge_attr_schemas[column_name] = AttrSchema( key=column_name, @@ -593,11 +602,16 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD else: schemas = self._edge_attr_schemas() - # Return schema overrides for special types that need explicit casting + # Return schema overrides for columns safely represented in SQL. + # Pickled columns are unpickled and casted in a second pass. return { key: schema.dtype for key, schema in schemas.items() - if not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + if ( + key in table_class.__table__.columns + and not isinstance(table_class.__table__.columns[key].type, sa.PickleType | sa.LargeBinary) + and not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + ) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: @@ -607,12 +621,19 @@ def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFra else: schemas = self._edge_attr_schemas() - # Cast array columns (stored as blobs in database) - df = df.with_columns( - pl.Series(key, df[key].to_list(), dtype=schema.dtype) - for key, schema in schemas.items() - if isinstance(schema.dtype, pl.Array) and key in df.columns - ) + casts: list[pl.Series] = [] + for key, schema in schemas.items(): + if key not in df.columns: + continue + + try: + casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + except Exception: + # Keep original dtype when values cannot be casted to the target schema. 
+ continue + + if casts: + df = df.with_columns(casts) return df def _update_max_id_per_time(self) -> None: @@ -1289,6 +1310,8 @@ def node_attrs( # indices are included by default and must be removed if attr_keys is not None: nodes_df = nodes_df.select([pl.col(c) for c in attr_keys]) + else: + nodes_df = nodes_df.select([pl.col(c) for c in self._node_attr_schemas() if c in nodes_df.columns]) if unpack: nodes_df = unpack_array_attrs(nodes_df) @@ -1331,6 +1354,8 @@ def edge_attrs( if unpack: edges_df = unpack_array_attrs(edges_df) + elif attr_keys is None: + edges_df = edges_df.select([pl.col(c) for c in self._edge_attr_schemas() if c in edges_df.columns]) return edges_df @@ -1575,6 +1600,9 @@ def _add_new_column( sa_column = sa.Column(schema.key, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) + identifier_preparer = self._engine.dialect.identifier_preparer + quoted_table_name = identifier_preparer.format_table(table_class.__table__) + quoted_column_name = identifier_preparer.quote(sa_column.name) # Properly quote default values based on type if isinstance(default_value, str): @@ -1585,8 +1613,8 @@ def _add_new_column( quoted_default = str(default_value) add_column_stmt = sa.DDL( - f"ALTER TABLE {table_class.__table__} ADD " - f"COLUMN {sa_column.name} {str_dialect_type} " + f"ALTER TABLE {quoted_table_name} ADD " + f"COLUMN {quoted_column_name} {str_dialect_type} " f"DEFAULT {quoted_default}", ) LOG.info("add %s column statement:\n'%s'", table_class.__table__, add_column_stmt) @@ -1601,7 +1629,10 @@ def _add_new_column( table_class.__table__.append_column(sa_column) def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: - drop_column_stmt = sa.DDL(f"ALTER TABLE {table_class.__table__} DROP COLUMN {key}") + identifier_preparer = self._engine.dialect.identifier_preparer + quoted_table_name = identifier_preparer.format_table(table_class.__table__) + quoted_column_name = 
identifier_preparer.quote(key) + drop_column_stmt = sa.DDL(f"ALTER TABLE {quoted_table_name} DROP COLUMN {quoted_column_name}") LOG.info("drop %s column statement:\n'%s'", table_class.__table__, drop_column_stmt) with Session(self._engine) as session: @@ -1625,6 +1656,7 @@ def add_node_attr_key( # Add column to database self._add_new_column(self.Node, schema) + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): @@ -1635,6 +1667,7 @@ def remove_node_attr_key(self, key: str) -> None: self._drop_column(self.Node, key) self.__node_attr_schemas.pop(key, None) + self._remove_attr_dtype_metadata(key=key, is_node=True) def add_edge_attr_key( self, @@ -1650,6 +1683,7 @@ def add_edge_attr_key( # Add column to database self._add_new_column(self.Edge, schema) + self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): @@ -1657,6 +1691,7 @@ def remove_edge_attr_key(self, key: str) -> None: self._drop_column(self.Edge, key) self.__edge_attr_schemas.pop(key, None) + self._remove_attr_dtype_metadata(key=key, is_node=False) def num_edges(self) -> int: with Session(self._engine) as session: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index e9088c75..f1ede1b2 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1437,6 +1437,83 @@ def test_from_other_with_edges( assert new_overlaps == source_overlaps +@pytest.mark.parametrize( + ("target_cls", "target_kwargs"), + [ + pytest.param(RustWorkXGraph, {}, id="rustworkx"), + pytest.param( + SQLGraph, + { + "drivername": "sqlite", + "database": ":memory:", + "engine_kwargs": {"connect_args": {"check_same_thread": False}}, + }, + id="sql", + ), + pytest.param(IndexedRXGraph, {}, 
id="indexed"), + ], +) +def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], target_kwargs: dict[str, Any]) -> None: + """Test that from_other preserves node and edge attribute schemas across backends.""" + graph = RustWorkXGraph() + for dtype in [ + pl.Float16, pl.Float32, + pl.Float64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Date, pl.Datetime, + pl.Boolean, + pl.Array(pl.Float32, 3), + pl.List(pl.Int32), + pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), + pl.String, + pl.Object]: + graph.add_node_attr_key(f"attr_{dtype}", dtype=dtype) + graph.add_node({"t":0, + "attr_Float16": np.float16(1.5), + "attr_Float32": np.float32(2.5), + "attr_Float64": np.float64(3.5), + "attr_Int8": np.int8(4), + "attr_Int16": np.int16(5), + "attr_Int32": np.int32(6), + "attr_Int64": np.int64(7), + "attr_UInt8": np.uint8(8), + "attr_UInt16": np.uint16(9), + "attr_UInt32": np.uint32(10), + "attr_UInt64": np.uint64(11), + "attr_Date": pl.date(2024, 1, 1), + "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), + "attr_Boolean": True, + "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), + "attr_List(Int32)": [1, 2, 3], + "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": {"a": 1, "b": np.array(["x", "y"], dtype=object)}, + "attr_String": "test", + "attr_Object": {"key": "value"}}) + graph2 = target_cls.from_other(graph, **target_kwargs) + + assert graph2.num_nodes() == graph.num_nodes() + assert set(graph2.node_attr_keys()) == set(graph.node_attr_keys()) + + assert graph2._node_attr_schemas() == graph._node_attr_schemas() + assert graph2._edge_attr_schemas() == graph._edge_attr_schemas() + assert graph2.node_attrs().schema == graph.node_attrs().schema + assert graph2.edge_attrs().schema == graph.edge_attrs().schema + + graph3 = RustWorkXGraph.from_other(graph2) + assert graph3._node_attr_schemas() == graph._node_attr_schemas() + assert 
graph3._edge_attr_schemas() == graph._edge_attr_schemas() + assert graph3.node_attrs().schema == graph.node_attrs().schema + assert graph3.edge_attrs().schema == graph.edge_attrs().schema + + + + @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 8e671487..f0f24e25 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,6 +1,8 @@ from __future__ import annotations +import base64 from dataclasses import dataclass +import io from typing import Any import numpy as np @@ -202,6 +204,37 @@ def copy(self) -> AttrSchema: """ return AttrSchema(key=self.key, dtype=self.dtype, default_value=self.default_value) + def __eq__(self, other: object) -> bool: + if not isinstance(other, AttrSchema): + return NotImplemented + return ( + self.key == other.key + and self.dtype == other.dtype + and _values_equal(self.default_value, other.default_value) + ) + + +def _values_equal(left: Any, right: Any) -> bool: + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + return bool(np.array_equal(left, right)) + if isinstance(left, dict) and isinstance(right, dict): + if left.keys() != right.keys(): + return False + return all(_values_equal(left[k], right[k]) for k in left) + if isinstance(left, list | tuple) and isinstance(right, list | tuple): + if len(left) != len(right): + return False + return all(_values_equal(lv, rv) for lv, rv in zip(left, right, strict=True)) + + try: + value = left == right + except Exception: + return False + + if isinstance(value, np.ndarray): + return bool(np.all(value)) + return bool(value) + def process_attr_key_args( key_or_schema: str | AttrSchema, @@ -445,6 +478,34 @@ def sqlalchemy_type_to_polars_dtype(sa_type: TypeEngine) -> pl.DataType: return pl.Object +def serialize_polars_dtype(dtype: pl.DataType) -> str: + """ + Serializes a Polars dtype to a safe, cross-platform base64 string + using the Arrow IPC format. 
+ """ + # Wrap the dtype in an empty DataFrame schema + # We use an empty DataFrame so no actual data is processed, only metadata. + dummy_df = pl.DataFrame(schema={"dummy": dtype}) + # Write to Arrow IPC (binary buffer) + # IPC is stable across versions/platforms unlike internal serialization. + buffer = io.BytesIO() + dummy_df.write_ipc(buffer) + # Encode binary to a standard string + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: + """ + Recovers a Polars dtype from a base64 string. + """ + # Decode string back to binary + data = base64.b64decode(encoded_dtype) + # Read the IPC buffer + buffer = io.BytesIO(data) + restored_df = pl.read_ipc(buffer) + # Extract the dtype from the schema + return restored_df.schema["dummy"] + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py new file mode 100644 index 00000000..3a997209 --- /dev/null +++ b/src/tracksdata/utils/_test/test_dtype_serialization.py @@ -0,0 +1,45 @@ +import base64 +import binascii + +import polars as pl +import pytest + +from tracksdata.utils._dtypes import deserialize_polars_dtype, serialize_polars_dtype + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Int64, + pl.Float32, + pl.Boolean, + pl.String, + pl.List(pl.Int16), + pl.Array(pl.Float64, 4), + pl.Array(pl.Int32, (2, 3)), + pl.Struct({"x": pl.Int64, "y": pl.List(pl.String)}), + pl.Datetime("us", "UTC"), + ], +) +def test_serialize_deserialize_polars_dtype_roundtrip(dtype: pl.DataType) -> None: + encoded = serialize_polars_dtype(dtype) + + assert isinstance(encoded, str) + assert encoded + assert base64.b64decode(encoded) + + restored_dtype = deserialize_polars_dtype(encoded) + + assert restored_dtype == dtype + + +def 
test_deserialize_polars_dtype_invalid_base64_raises() -> None: + with pytest.raises(binascii.Error): + deserialize_polars_dtype("not-base64") + + +def test_deserialize_polars_dtype_non_ipc_payload_raises() -> None: + encoded = base64.b64encode(b"not-arrow-ipc").decode("utf-8") + + with pytest.raises((OSError, pl.exceptions.PolarsError)): + deserialize_polars_dtype(encoded) From c50a07b9eaf33665312b3b8db58ac8107b6d2a77 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 17 Feb 2026 15:48:05 +0900 Subject: [PATCH 06/30] lint --- src/tracksdata/graph/_base_graph.py | 16 ++-- src/tracksdata/graph/_sql_graph.py | 2 +- .../graph/_test/test_graph_backends.py | 86 ++++++++++--------- src/tracksdata/utils/_dtypes.py | 5 +- 4 files changed, 57 insertions(+), 52 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 5c2c8f85..9f35119e 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1,10 +1,10 @@ import abc import functools import operator +import warnings from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload -import warnings import geff import numpy as np @@ -1999,22 +1999,18 @@ def _attr_dtype_from_metadata(self, *, key: str, is_node: bool) -> pl.DataType | except Exception: warnings.warn( f"Initializing schemas from existing database tables for the key {key}. " - "This is a fallback mechanism when loading existing graphs, and may not perfectly restore the original schemas. " + "This is a fallback mechanism when loading existing graphs, and may not " + "perfectly restore the original schemas. " "This method is deprecated and will be removed in the major release. 
", UserWarning, + stacklevel=2, ) return None def _sync_attr_dtype_metadata(self) -> None: dtype_map = { - "node": { - key: serialize_polars_dtype(schema.dtype) - for key, schema in self._node_attr_schemas().items() - }, - "edge": { - key: serialize_polars_dtype(schema.dtype) - for key, schema in self._edge_attr_schemas().items() - }, + "node": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._node_attr_schemas().items()}, + "edge": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._edge_attr_schemas().items()}, } self._set_private_dtype_map(dtype_map) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index b36e3ab8..735a10d3 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -557,7 +557,7 @@ def _init_schemas_from_tables(self) -> None: Initialize AttrSchema objects from existing database table columns. This is used when loading an existing graph from the database. """ - + node_column_names = list(self.Node.__table__.columns.keys()) preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index f1ede1b2..0bc7dcf9 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1457,44 +1457,54 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ """Test that from_other preserves node and edge attribute schemas across backends.""" graph = RustWorkXGraph() for dtype in [ - pl.Float16, pl.Float32, - pl.Float64, - pl.Int8, - pl.Int16, - pl.Int32, - pl.Int64, - pl.UInt8, - pl.UInt16, - pl.UInt32, - pl.UInt64, - pl.Date, pl.Datetime, - pl.Boolean, - pl.Array(pl.Float32, 3), - pl.List(pl.Int32), - pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), - 
pl.String, - pl.Object]: + pl.Float16, + pl.Float32, + pl.Float64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Date, + pl.Datetime, + pl.Boolean, + pl.Array(pl.Float32, 3), + pl.List(pl.Int32), + pl.Struct({"a": pl.Int8, "b": pl.Array(pl.String, 2)}), + pl.String, + pl.Object, + ]: graph.add_node_attr_key(f"attr_{dtype}", dtype=dtype) - graph.add_node({"t":0, - "attr_Float16": np.float16(1.5), - "attr_Float32": np.float32(2.5), - "attr_Float64": np.float64(3.5), - "attr_Int8": np.int8(4), - "attr_Int16": np.int16(5), - "attr_Int32": np.int32(6), - "attr_Int64": np.int64(7), - "attr_UInt8": np.uint8(8), - "attr_UInt16": np.uint16(9), - "attr_UInt32": np.uint32(10), - "attr_UInt64": np.uint64(11), - "attr_Date": pl.date(2024, 1, 1), - "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), - "attr_Boolean": True, - "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), - "attr_List(Int32)": [1, 2, 3], - "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": {"a": 1, "b": np.array(["x", "y"], dtype=object)}, - "attr_String": "test", - "attr_Object": {"key": "value"}}) + graph.add_node( + { + "t": 0, + "attr_Float16": np.float16(1.5), + "attr_Float32": np.float32(2.5), + "attr_Float64": np.float64(3.5), + "attr_Int8": np.int8(4), + "attr_Int16": np.int16(5), + "attr_Int32": np.int32(6), + "attr_Int64": np.int64(7), + "attr_UInt8": np.uint8(8), + "attr_UInt16": np.uint16(9), + "attr_UInt32": np.uint32(10), + "attr_UInt64": np.uint64(11), + "attr_Date": pl.date(2024, 1, 1), + "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), + "attr_Boolean": True, + "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), + "attr_List(Int32)": [1, 2, 3], + "attr_Struct({'a': Int8, 'b': Array(String, shape=(2,))})": { + "a": 1, + "b": np.array(["x", "y"], dtype=object), + }, + "attr_String": "test", + "attr_Object": {"key": "value"}, + } + ) graph2 = target_cls.from_other(graph, 
**target_kwargs) assert graph2.num_nodes() == graph.num_nodes() @@ -1512,8 +1522,6 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ assert graph3.edge_attrs().schema == graph.edge_attrs().schema - - @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index f0f24e25..90fc2006 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,8 +1,8 @@ from __future__ import annotations import base64 -from dataclasses import dataclass import io +from dataclasses import dataclass from typing import Any import numpy as np @@ -491,7 +491,7 @@ def serialize_polars_dtype(dtype: pl.DataType) -> str: buffer = io.BytesIO() dummy_df.write_ipc(buffer) # Encode binary to a standard string - return base64.b64encode(buffer.getvalue()).decode('utf-8') + return base64.b64encode(buffer.getvalue()).decode("utf-8") def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: @@ -506,6 +506,7 @@ def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: # Extract the dtype from the schema return restored_df.schema["dummy"] + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. 
From e9bf28f90f4591b4fb82c1a6c93e7703fbe8aae3 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 10:45:18 +0900 Subject: [PATCH 07/30] restricted dtype metadata to sqlgraph --- src/tracksdata/graph/_base_graph.py | 73 +-------- src/tracksdata/graph/_rustworkx_graph.py | 32 ++-- src/tracksdata/graph/_sql_graph.py | 151 +++++++++++++----- .../graph/_test/test_graph_backends.py | 64 ++++++++ src/tracksdata/utils/_dtypes.py | 91 +++++++++++ .../utils/_test/test_dtype_serialization.py | 27 +++- 6 files changed, 310 insertions(+), 128 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 9f35119e..90500b67 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1,7 +1,6 @@ import abc import functools import operator -import warnings from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload @@ -22,9 +21,7 @@ from tracksdata.utils._dtypes import ( AttrSchema, column_to_numpy, - deserialize_polars_dtype, polars_dtype_to_numpy_dtype, - serialize_polars_dtype, ) from tracksdata.utils._logging import LOG from tracksdata.utils._multiprocessing import multiprocessing_apply @@ -113,7 +110,6 @@ class BaseGraph(abc.ABC): """ _PRIVATE_METADATA_PREFIX = "__private_" - _PRIVATE_DTYPE_MAP_KEY = "__private_dtype_map" node_added = Signal(int) node_removed = Signal(int) @@ -1255,7 +1251,7 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.metadata.update(other.metadata) - graph._private_metadata.update(other._private_metadata) + graph._private_metadata.update(other._private_metadata_for_copy()) current_node_attr_schemas = graph._node_attr_schemas() for k, v in other._node_attr_schemas().items(): @@ -1951,68 +1947,13 @@ def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) self._validate_metadata_key(key, is_public=is_public) 
self._remove_metadata(key) - def _get_private_dtype_map(self) -> dict[str, dict[str, str]]: - dtype_map = self._private_metadata.get(self._PRIVATE_DTYPE_MAP_KEY, {}) - if not isinstance(dtype_map, dict): - return {"node": {}, "edge": {}} - - node_dtype_map = dtype_map.get("node", {}) - edge_dtype_map = dtype_map.get("edge", {}) - if not isinstance(node_dtype_map, dict): - node_dtype_map = {} - if not isinstance(edge_dtype_map, dict): - edge_dtype_map = {} - - return {"node": dict(node_dtype_map), "edge": dict(edge_dtype_map)} - - def _set_private_dtype_map(self, dtype_map: dict[str, dict[str, str]]) -> None: - self._private_metadata.update( - **{ - self._PRIVATE_DTYPE_MAP_KEY: { - "node": dict(dtype_map.get("node", {})), - "edge": dict(dtype_map.get("edge", {})), - } - } - ) - - def _set_attr_dtype_metadata(self, *, key: str, dtype: pl.DataType, is_node: bool) -> None: - dtype_map = self._get_private_dtype_map() - map_key = "node" if is_node else "edge" - dtype_map[map_key][key] = serialize_polars_dtype(dtype) - self._set_private_dtype_map(dtype_map) - - def _remove_attr_dtype_metadata(self, *, key: str, is_node: bool) -> None: - dtype_map = self._get_private_dtype_map() - map_key = "node" if is_node else "edge" - dtype_map[map_key].pop(key, None) - self._set_private_dtype_map(dtype_map) - - def _attr_dtype_from_metadata(self, *, key: str, is_node: bool) -> pl.DataType | None: - dtype_map = self._get_private_dtype_map() - map_key = "node" if is_node else "edge" - encoded_dtype = dtype_map[map_key].get(key) - if not isinstance(encoded_dtype, str): - return None - - try: - return deserialize_polars_dtype(encoded_dtype) - except Exception: - warnings.warn( - f"Initializing schemas from existing database tables for the key {key}. " - "This is a fallback mechanism when loading existing graphs, and may not " - "perfectly restore the original schemas. " - "This method is deprecated and will be removed in the major release. 
", - UserWarning, - stacklevel=2, - ) - return None + def _private_metadata_for_copy(self) -> dict[str, Any]: + """ + Return private metadata entries that should be propagated by `from_other`. - def _sync_attr_dtype_metadata(self) -> None: - dtype_map = { - "node": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._node_attr_schemas().items()}, - "edge": {key: serialize_polars_dtype(schema.dtype) for key, schema in self._edge_attr_schemas().items()}, - } - self._set_private_dtype_map(dtype_map) + Backends can override this to exclude backend-specific private metadata. + """ + return dict(self._private_metadata) @abc.abstractmethod def _metadata(self) -> dict[str, Any]: diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index a15dd1f5..a415b89d 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -343,7 +343,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: self._time_to_nodes: dict[int, list[int]] = {} self.__node_attr_schemas: dict[str, AttrSchema] = {} self.__edge_attr_schemas: dict[str, AttrSchema] = {} - self._overlaps: list[list[int, 2]] = [] + self._overlaps: list[list[int]] = [] # Add default node attributes with inferred schemas self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema( @@ -400,13 +400,11 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: for key, value in first_node_attrs.items(): if key == DEFAULT_ATTR_KEYS.NODE_ID: continue - dtype = self._attr_dtype_from_metadata(key=key, is_node=True) - if dtype is None: - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__node_attr_schemas[key] = AttrSchema(key=key, 
dtype=dtype) # Process edges: set edge IDs and infer schemas @@ -424,17 +422,13 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: # TODO: check if EDGE_SOURCE and EDGE_TARGET should be also ignored or in the schema if key == DEFAULT_ATTR_KEYS.EDGE_ID: continue - dtype = self._attr_dtype_from_metadata(key=key, is_node=False) - if dtype is None: - try: - dtype = pl.Series([value]).dtype - except (ValueError, TypeError): - # If polars can't infer dtype (e.g., for complex objects), use Object - dtype = pl.Object + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self.__edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) - self._sync_attr_dtype_metadata() - def _node_attr_schemas(self) -> dict[str, AttrSchema]: return self.__node_attr_schemas @@ -992,7 +986,6 @@ def add_node_attr_key( # Store schema self.__node_attr_schemas[schema.key] = schema - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) def remove_node_attr_key(self, key: str) -> None: """ @@ -1005,7 +998,6 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") del self.__node_attr_schemas[key] - self._remove_attr_dtype_metadata(key=key, is_node=True) for node_attr in self.rx_graph.nodes(): node_attr.pop(key, None) @@ -1034,7 +1026,6 @@ def add_edge_attr_key( # Store schema self.__edge_attr_schemas[schema.key] = schema - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) def remove_edge_attr_key(self, key: str) -> None: """ @@ -1044,7 +1035,6 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") del self.__edge_attr_schemas[key] - self._remove_attr_dtype_metadata(key=key, is_node=False) for edge_attr in self.rx_graph.edges(): edge_attr.pop(key, None) diff --git 
a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 735a10d3..b9da5c1a 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -20,8 +20,10 @@ from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns from tracksdata.utils._dtypes import ( AttrSchema, + deserialize_attr_schema, polars_dtype_to_sqlalchemy_type, process_attr_key_args, + serialize_attr_schema, sqlalchemy_type_to_polars_dtype, ) from tracksdata.utils._logging import LOG @@ -441,6 +443,7 @@ class SQLGraph(BaseGraph): """ node_id_time_multiplier: int = 1_000_000_000 + _PRIVATE_SQL_SCHEMA_STORE_KEY = "__private_sql_attr_schema_store" Base: type[DeclarativeBase] Node: type[DeclarativeBase] Edge: type[DeclarativeBase] @@ -469,8 +472,6 @@ def __init__( # Create unique classes for this instance self._define_schema(overwrite=overwrite) - self.__node_attr_schemas: dict[str, AttrSchema] = {} - self.__edge_attr_schemas: dict[str, AttrSchema] = {} if overwrite: self.Base.metadata.drop_all(self._engine) @@ -479,7 +480,6 @@ def __init__( # Initialize schemas from existing table columns self._init_schemas_from_tables() - self._sync_attr_dtype_metadata() self._max_id_per_time = {} self._update_max_id_per_time() @@ -552,43 +552,109 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata + @classmethod + def _empty_attr_schema_store(cls) -> dict[str, dict[str, str]]: + return {"node": {}, "edge": {}} + + def _attr_schema_store(self) -> dict[str, dict[str, str]]: + store = self._private_metadata.get(self._PRIVATE_SQL_SCHEMA_STORE_KEY, {}) + if not isinstance(store, dict): + return self._empty_attr_schema_store() + + normalized = self._empty_attr_schema_store() + for section_key in ("node", "edge"): + section = store.get(section_key, {}) + if not isinstance(section, dict): + continue + for key, encoded_schema in section.items(): + if isinstance(encoded_schema, str): + normalized[section_key][key] = encoded_schema + + 
return normalized + + def _set_attr_schema_store(self, store: dict[str, dict[str, str]]) -> None: + normalized = self._empty_attr_schema_store() + for section_key in ("node", "edge"): + section = store.get(section_key, {}) + if not isinstance(section, dict): + continue + for key, encoded_schema in section.items(): + if isinstance(encoded_schema, str): + normalized[section_key][key] = encoded_schema + + self._private_metadata.update(**{self._PRIVATE_SQL_SCHEMA_STORE_KEY: normalized}) + + def _get_attr_schemas_from_store(self, *, is_node: bool) -> dict[str, AttrSchema]: + section_key = "node" if is_node else "edge" + section = self._attr_schema_store()[section_key] + + schemas: dict[str, AttrSchema] = {} + for key, encoded_schema in section.items(): + try: + schemas[key] = deserialize_attr_schema(encoded_schema, key=key) + except Exception: + LOG.warning( + "Failed to deserialize SQL schema metadata for key '%s'. Falling back to table inference.", + key, + ) + + return schemas + + def _set_attr_schemas_to_store(self, *, is_node: bool, schemas: dict[str, AttrSchema]) -> None: + section_key = "node" if is_node else "edge" + store = self._attr_schema_store() + store[section_key] = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._set_attr_schema_store(store) + + @property + def __node_attr_schemas(self) -> dict[str, AttrSchema]: + return self._get_attr_schemas_from_store(is_node=True) + + @__node_attr_schemas.setter + def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + self._set_attr_schemas_to_store(is_node=True, schemas=schemas) + + @property + def __edge_attr_schemas(self) -> dict[str, AttrSchema]: + return self._get_attr_schemas_from_store(is_node=False) + + @__edge_attr_schemas.setter + def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + self._set_attr_schemas_to_store(is_node=False, schemas=schemas) + def _init_schemas_from_tables(self) -> None: """ Initialize AttrSchema objects from existing 
database table columns. This is used when loading an existing graph from the database. """ - node_column_names = list(self.Node.__table__.columns.keys()) preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] ordered_node_columns.extend(name for name in node_column_names if name not in preferred_node_order) - # Initialize node schemas from Node table columns + node_schemas = {k: v for k, v in self.__node_attr_schemas.items() if k in ordered_node_columns} for column_name in ordered_node_columns: - if column_name not in self.__node_attr_schemas: - pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=True) - if pl_dtype is None: - column = self.Node.__table__.columns[column_name] - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) - # AttrSchema.__post_init__ will infer the default_value - self.__node_attr_schemas[column_name] = AttrSchema( - key=column_name, - dtype=pl_dtype, - ) + if column_name in node_schemas: + continue + column = self.Node.__table__.columns[column_name] + node_schemas[column_name] = AttrSchema( + key=column_name, + dtype=sqlalchemy_type_to_polars_dtype(column.type), + ) + self.__node_attr_schemas = node_schemas # Initialize edge schemas from Edge table columns + edge_column_names = list(self.Edge.__table__.columns.keys()) + edge_schemas = {k: v for k, v in self.__edge_attr_schemas.items() if k in edge_column_names} for column_name in self.Edge.__table__.columns.keys(): - # Skip internal edge columns - if column_name not in self.__edge_attr_schemas: - pl_dtype = self._attr_dtype_from_metadata(key=column_name, is_node=False) - if pl_dtype is None: - column = self.Edge.__table__.columns[column_name] - pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) - # AttrSchema.__post_init__ will infer the default_value - self.__edge_attr_schemas[column_name] = AttrSchema( - key=column_name, - dtype=pl_dtype, - ) + if 
column_name in edge_schemas: + continue + column = self.Edge.__table__.columns[column_name] + edge_schemas[column_name] = AttrSchema( + key=column_name, + dtype=sqlalchemy_type_to_polars_dtype(column.type), + ) + self.__edge_attr_schemas = edge_schemas def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: @@ -1648,15 +1714,14 @@ def add_node_attr_key( dtype: pl.DataType | None = None, default_value: Any = None, ) -> None: + node_schemas = self.__node_attr_schemas # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__node_attr_schemas) - - # Store schema - self.__node_attr_schemas[schema.key] = schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, node_schemas) # Add column to database self._add_new_column(self.Node, schema) - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=True) + node_schemas[schema.key] = schema + self.__node_attr_schemas = node_schemas def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): @@ -1665,9 +1730,10 @@ def remove_node_attr_key(self, key: str) -> None: if key in (DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T): raise ValueError(f"Cannot remove required node attribute key {key}") + node_schemas = self.__node_attr_schemas self._drop_column(self.Node, key) - self.__node_attr_schemas.pop(key, None) - self._remove_attr_dtype_metadata(key=key, is_node=True) + node_schemas.pop(key, None) + self.__node_attr_schemas = node_schemas def add_edge_attr_key( self, @@ -1675,23 +1741,23 @@ def add_edge_attr_key( dtype: pl.DataType | None = None, default_value: Any = None, ) -> None: + edge_schemas = self.__edge_attr_schemas # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__edge_attr_schemas) - - # Store schema - self.__edge_attr_schemas[schema.key] = schema + schema = 
process_attr_key_args(key_or_schema, dtype, default_value, edge_schemas) # Add column to database self._add_new_column(self.Edge, schema) - self._set_attr_dtype_metadata(key=schema.key, dtype=schema.dtype, is_node=False) + edge_schemas[schema.key] = schema + self.__edge_attr_schemas = edge_schemas def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") + edge_schemas = self.__edge_attr_schemas self._drop_column(self.Edge, key) - self.__edge_attr_schemas.pop(key, None) - self._remove_attr_dtype_metadata(key=key, is_node=False) + edge_schemas.pop(key, None) + self.__edge_attr_schemas = edge_schemas def num_edges(self) -> int: with Session(self._engine) as session: @@ -2032,6 +2098,11 @@ def _metadata(self) -> dict[str, Any]: result = session.query(self.Metadata).all() return {row.key: row.value for row in result} + def _private_metadata_for_copy(self) -> dict[str, Any]: + private_metadata = super()._private_metadata_for_copy() + private_metadata.pop(self._PRIVATE_SQL_SCHEMA_STORE_KEY, None) + return private_metadata + def _update_metadata(self, **kwargs) -> None: with Session(self._engine) as session: for key, value in kwargs.items(): diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 0bc7dcf9..99ee386e 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1704,6 +1704,70 @@ def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: assert next_id == first_id + 1 +def test_sql_graph_schema_defaults_survive_reload(tmp_path: Path) -> None: + """Reloading a SQLGraph should preserve dtype and default schema metadata.""" + db_path = tmp_path / "schema_defaults.db" + graph = SQLGraph("sqlite", str(db_path)) + + node_array_default = np.array([1.0, 2.0, 3.0], dtype=np.float32) + node_object_default = {"nested": [1, 2, 3]} + 
edge_score_default = 0.25 + + graph.add_node_attr_key("node_array_default", pl.Array(pl.Float32, 3), node_array_default) + graph.add_node_attr_key("node_object_default", pl.Object, node_object_default) + graph.add_edge_attr_key("edge_score_default", pl.Float32, edge_score_default) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + node_schemas = reloaded._node_attr_schemas() + edge_schemas = reloaded._edge_attr_schemas() + np.testing.assert_array_equal(node_schemas["node_array_default"].default_value, node_array_default) + assert node_schemas["node_array_default"].dtype == pl.Array(pl.Float32, 3) + assert node_schemas["node_object_default"].default_value == node_object_default + assert node_schemas["node_object_default"].dtype == pl.Object + assert edge_schemas["edge_score_default"].default_value == edge_score_default + assert edge_schemas["edge_score_default"].dtype == pl.Float32 + + +def test_sql_schema_metadata_not_copied_to_in_memory_graphs() -> None: + """SQL-private schema metadata should not leak into in-memory backends via from_other.""" + sql_graph = SQLGraph("sqlite", ":memory:") + sql_graph.add_node_attr_key("node_array_default", pl.Array(pl.Float32, 3), np.array([1.0, 2.0, 3.0], np.float32)) + sql_graph.add_node_attr_key("node_object_default", pl.Object, {"payload": [1, 2, 3]}) + sql_graph.add_edge_attr_key("edge_score_default", pl.Float32, 0.25) + + n1 = sql_graph.add_node( + { + "t": 0, + "node_array_default": np.array([1.0, 1.0, 1.0], dtype=np.float32), + "node_object_default": {"payload": [10]}, + } + ) + n2 = sql_graph.add_node( + { + "t": 1, + "node_array_default": np.array([2.0, 2.0, 2.0], dtype=np.float32), + "node_object_default": {"payload": [20]}, + } + ) + sql_graph.add_edge(n1, n2, {"edge_score_default": 0.75}) + + assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY in sql_graph._private_metadata + + rx_graph = RustWorkXGraph.from_other(sql_graph) + assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY not in 
rx_graph._metadata() + + sql_graph_roundtrip = SQLGraph.from_other( + rx_graph, + drivername="sqlite", + database=":memory:", + engine_kwargs={"connect_args": {"check_same_thread": False}}, + ) + assert sql_graph_roundtrip._node_attr_schemas() == sql_graph._node_attr_schemas() + assert sql_graph_roundtrip._edge_attr_schemas() == sql_graph._edge_attr_schemas() + + def test_compute_overlaps_invalid_threshold(graph_backend: BaseGraph) -> None: """Test compute_overlaps with invalid threshold values.""" with pytest.raises(ValueError, match=r"iou_threshold must be between 0.0 and 1\.0"): diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 90fc2006..0245acdc 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -507,6 +507,97 @@ def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: return restored_df.schema["dummy"] +_ATTR_SCHEMA_DTYPE_COL = "__attr_schema_dtype__" +_ATTR_SCHEMA_DEFAULT_COL = "__attr_schema_default__" +_ATTR_SCHEMA_DTYPE_PICKLE_COL = "__attr_schema_dtype_pickle__" + + +def serialize_attr_schema(schema: AttrSchema) -> str: + """ + Serialize an AttrSchema into a base64-encoded Arrow IPC payload. + + The payload stores dtype metadata and the default value in the same + DataFrame serialization so schema roundtrip can restore both fields. + """ + default_payload = dumps(schema.default_value) + dtype_payload = dumps(schema.dtype) + df = pl.DataFrame( + { + _ATTR_SCHEMA_DTYPE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_COL, + values=[None], + dtype=schema.dtype, + ), + _ATTR_SCHEMA_DEFAULT_COL: pl.Series( + _ATTR_SCHEMA_DEFAULT_COL, + values=[default_payload], + dtype=pl.Binary, + ), + _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_PICKLE_COL, + values=[dtype_payload], + dtype=pl.Binary, + ), + } + ) + + buffer = io.BytesIO() + try: + df.write_ipc(buffer) + except Exception: + # Fallback for dtypes that cannot be represented in Arrow IPC schema + # (e.g., pl.Object). 
Keep everything in one DataFrame payload. + fallback_df = pl.DataFrame( + { + _ATTR_SCHEMA_DTYPE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_COL, + values=[None], + dtype=pl.Binary, + ), + _ATTR_SCHEMA_DEFAULT_COL: pl.Series( + _ATTR_SCHEMA_DEFAULT_COL, + values=[default_payload], + dtype=pl.Binary, + ), + _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( + _ATTR_SCHEMA_DTYPE_PICKLE_COL, + values=[dtype_payload], + dtype=pl.Binary, + ), + } + ) + buffer = io.BytesIO() + fallback_df.write_ipc(buffer) + + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def deserialize_attr_schema(encoded_schema: str, *, key: str) -> AttrSchema: + """ + Deserialize an AttrSchema previously encoded by `serialize_attr_schema`. + """ + data = base64.b64decode(encoded_schema) + buffer = io.BytesIO(data) + restored_df = pl.read_ipc(buffer) + + if _ATTR_SCHEMA_DTYPE_PICKLE_COL in restored_df.columns: + dtype_pickle = restored_df[_ATTR_SCHEMA_DTYPE_PICKLE_COL][0] + else: + dtype_pickle = None + + if dtype_pickle is not None: + dtype = loads(dtype_pickle) + else: + dtype = restored_df.schema[_ATTR_SCHEMA_DTYPE_COL] + + if not pl.datatypes.is_polars_dtype(dtype): + raise TypeError(f"Decoded value is not a polars dtype: {type(dtype)}") + + default_payload = restored_df[_ATTR_SCHEMA_DEFAULT_COL][0] + default_value = loads(default_payload) if default_payload is not None else None + return AttrSchema(key=key, dtype=dtype, default_value=default_value) + + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. 
diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py index 3a997209..51b2659a 100644 --- a/src/tracksdata/utils/_test/test_dtype_serialization.py +++ b/src/tracksdata/utils/_test/test_dtype_serialization.py @@ -1,10 +1,17 @@ import base64 import binascii +import numpy as np import polars as pl import pytest -from tracksdata.utils._dtypes import deserialize_polars_dtype, serialize_polars_dtype +from tracksdata.utils._dtypes import ( + AttrSchema, + deserialize_attr_schema, + deserialize_polars_dtype, + serialize_attr_schema, + serialize_polars_dtype, +) @pytest.mark.parametrize( @@ -43,3 +50,21 @@ def test_deserialize_polars_dtype_non_ipc_payload_raises() -> None: with pytest.raises((OSError, pl.exceptions.PolarsError)): deserialize_polars_dtype(encoded) + + +@pytest.mark.parametrize( + "schema", + [ + AttrSchema(key="score", dtype=pl.Float64, default_value=1.25), + AttrSchema( + key="vector", + dtype=pl.Array(pl.Float32, 3), + default_value=np.array([1.0, 2.0, 3.0], dtype=np.float32), + ), + AttrSchema(key="payload", dtype=pl.Object, default_value={"nested": [1, 2, 3]}), + ], +) +def test_serialize_deserialize_attr_schema_roundtrip(schema: AttrSchema) -> None: + encoded = serialize_attr_schema(schema) + restored = deserialize_attr_schema(encoded, key=schema.key) + assert restored == schema From 9aa9c3a686754a9c540ffd8bbe7c8507579eeecd Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 11:03:12 +0900 Subject: [PATCH 08/30] updated serialization strategies --- .../graph/_test/test_graph_backends.py | 3 +- src/tracksdata/utils/_dtypes.py | 113 +++++++----------- .../utils/_test/test_dtype_serialization.py | 33 +++-- 3 files changed, 68 insertions(+), 81 deletions(-) diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 99ee386e..63013a4d 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ 
b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1,3 +1,4 @@ +import datetime as dt from pathlib import Path from typing import Any @@ -1493,7 +1494,7 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ "attr_UInt32": np.uint32(10), "attr_UInt64": np.uint64(11), "attr_Date": pl.date(2024, 1, 1), - "attr_Datetime": pl.datetime(2024, 1, 1, 12, 0, 0), + "attr_Datetime": dt.datetime(2024, 1, 1, 12, 0, 0), "attr_Boolean": True, "attr_Array(Float32, shape=(3,))": np.array([1.0, 2.0, 3.0], dtype=np.float32), "attr_List(Int32)": [1, 2, 3], diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 0245acdc..05338fa5 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -478,66 +478,48 @@ def sqlalchemy_type_to_polars_dtype(sa_type: TypeEngine) -> pl.DataType: return pl.Object -def serialize_polars_dtype(dtype: pl.DataType) -> str: - """ - Serializes a Polars dtype to a safe, cross-platform base64 string - using the Arrow IPC format. - """ - # Wrap the dtype in an empty DataFrame schema - # We use an empty DataFrame so no actual data is processed, only metadata. - dummy_df = pl.DataFrame(schema={"dummy": dtype}) - # Write to Arrow IPC (binary buffer) - # IPC is stable across versions/platforms unlike internal serialization. - buffer = io.BytesIO() - dummy_df.write_ipc(buffer) - # Encode binary to a standard string - return base64.b64encode(buffer.getvalue()).decode("utf-8") +def _normalize_default_for_dtype(default_value: Any, dtype: pl.DataType) -> Any: + if isinstance(dtype, pl.Array | pl.List) and isinstance(default_value, np.ndarray): + return default_value.tolist() + return default_value -def deserialize_polars_dtype(encoded_dtype: str) -> pl.DataType: - """ - Recovers a Polars dtype from a base64 string. 
- """ - # Decode string back to binary - data = base64.b64decode(encoded_dtype) - # Read the IPC buffer - buffer = io.BytesIO(data) - restored_df = pl.read_ipc(buffer) - # Extract the dtype from the schema - return restored_df.schema["dummy"] +def _normalize_deserialized_default(default_value: Any, dtype: pl.DataType) -> Any: + if isinstance(dtype, pl.Array): + if isinstance(default_value, pl.Series): + default_value = default_value.to_list() + numpy_dtype = polars_dtype_to_numpy_dtype(dtype.inner, allow_sequence=True) + return np.asarray(default_value, dtype=numpy_dtype).reshape(dtype.shape) + + if isinstance(dtype, pl.List): + if isinstance(default_value, pl.Series): + return default_value.to_list() + if isinstance(default_value, np.ndarray): + return default_value.tolist() + + return default_value -_ATTR_SCHEMA_DTYPE_COL = "__attr_schema_dtype__" -_ATTR_SCHEMA_DEFAULT_COL = "__attr_schema_default__" -_ATTR_SCHEMA_DTYPE_PICKLE_COL = "__attr_schema_dtype_pickle__" +_ATTR_SCHEMA_VALUE_COL = "__attr_schema_value__" +_ATTR_SCHEMA_FALLBACK_COL = "__attr_schema_fallback__" def serialize_attr_schema(schema: AttrSchema) -> str: """ Serialize an AttrSchema into a base64-encoded Arrow IPC payload. - The payload stores dtype metadata and the default value in the same - DataFrame serialization so schema roundtrip can restore both fields. + The primary format stores schema.default_value in the first row of a + single dummy column whose dtype is schema.dtype. This keeps dtype and + default value in one Arrow IPC payload. 
""" - default_payload = dumps(schema.default_value) - dtype_payload = dumps(schema.dtype) + normalized_default = _normalize_default_for_dtype(schema.default_value, schema.dtype) df = pl.DataFrame( { - _ATTR_SCHEMA_DTYPE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_COL, - values=[None], + _ATTR_SCHEMA_VALUE_COL: pl.Series( + _ATTR_SCHEMA_VALUE_COL, + values=[normalized_default], dtype=schema.dtype, ), - _ATTR_SCHEMA_DEFAULT_COL: pl.Series( - _ATTR_SCHEMA_DEFAULT_COL, - values=[default_payload], - dtype=pl.Binary, - ), - _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_PICKLE_COL, - values=[dtype_payload], - dtype=pl.Binary, - ), } ) @@ -545,23 +527,14 @@ def serialize_attr_schema(schema: AttrSchema) -> str: try: df.write_ipc(buffer) except Exception: - # Fallback for dtypes that cannot be represented in Arrow IPC schema - # (e.g., pl.Object). Keep everything in one DataFrame payload. + # Some dtypes (e.g. pl.Object) cannot roundtrip through Arrow IPC schema. + # Store pickled (dtype, default) in the first row of a binary dummy column. 
+ fallback_payload = dumps((schema.dtype, schema.default_value)) fallback_df = pl.DataFrame( { - _ATTR_SCHEMA_DTYPE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_COL, - values=[None], - dtype=pl.Binary, - ), - _ATTR_SCHEMA_DEFAULT_COL: pl.Series( - _ATTR_SCHEMA_DEFAULT_COL, - values=[default_payload], - dtype=pl.Binary, - ), - _ATTR_SCHEMA_DTYPE_PICKLE_COL: pl.Series( - _ATTR_SCHEMA_DTYPE_PICKLE_COL, - values=[dtype_payload], + _ATTR_SCHEMA_FALLBACK_COL: pl.Series( + _ATTR_SCHEMA_FALLBACK_COL, + values=[fallback_payload], dtype=pl.Binary, ), } @@ -580,21 +553,21 @@ def deserialize_attr_schema(encoded_schema: str, *, key: str) -> AttrSchema: buffer = io.BytesIO(data) restored_df = pl.read_ipc(buffer) - if _ATTR_SCHEMA_DTYPE_PICKLE_COL in restored_df.columns: - dtype_pickle = restored_df[_ATTR_SCHEMA_DTYPE_PICKLE_COL][0] - else: - dtype_pickle = None - - if dtype_pickle is not None: - dtype = loads(dtype_pickle) + if _ATTR_SCHEMA_VALUE_COL in restored_df.columns: + dtype = restored_df.schema[_ATTR_SCHEMA_VALUE_COL] + default_value = restored_df[_ATTR_SCHEMA_VALUE_COL][0] + elif _ATTR_SCHEMA_FALLBACK_COL in restored_df.columns: + fallback_payload = restored_df[_ATTR_SCHEMA_FALLBACK_COL][0] + if fallback_payload is None: + raise ValueError("Fallback schema payload is missing.") + dtype, default_value = loads(fallback_payload) else: - dtype = restored_df.schema[_ATTR_SCHEMA_DTYPE_COL] + raise ValueError("Unrecognized attr schema payload format.") if not pl.datatypes.is_polars_dtype(dtype): raise TypeError(f"Decoded value is not a polars dtype: {type(dtype)}") - default_payload = restored_df[_ATTR_SCHEMA_DEFAULT_COL][0] - default_value = loads(default_payload) if default_payload is not None else None + default_value = _normalize_deserialized_default(default_value, dtype) return AttrSchema(key=key, dtype=dtype, default_value=default_value) diff --git a/src/tracksdata/utils/_test/test_dtype_serialization.py b/src/tracksdata/utils/_test/test_dtype_serialization.py index 
51b2659a..1f406224 100644 --- a/src/tracksdata/utils/_test/test_dtype_serialization.py +++ b/src/tracksdata/utils/_test/test_dtype_serialization.py @@ -1,5 +1,6 @@ import base64 import binascii +import io import numpy as np import polars as pl @@ -8,9 +9,7 @@ from tracksdata.utils._dtypes import ( AttrSchema, deserialize_attr_schema, - deserialize_polars_dtype, serialize_attr_schema, - serialize_polars_dtype, ) @@ -28,28 +27,29 @@ pl.Datetime("us", "UTC"), ], ) -def test_serialize_deserialize_polars_dtype_roundtrip(dtype: pl.DataType) -> None: - encoded = serialize_polars_dtype(dtype) +def test_serialize_deserialize_attr_schema_dtype_roundtrip(dtype: pl.DataType) -> None: + schema = AttrSchema(key="dummy", dtype=dtype) + encoded = serialize_attr_schema(schema) assert isinstance(encoded, str) assert encoded assert base64.b64decode(encoded) - restored_dtype = deserialize_polars_dtype(encoded) + restored = deserialize_attr_schema(encoded, key=schema.key) - assert restored_dtype == dtype + assert restored == schema -def test_deserialize_polars_dtype_invalid_base64_raises() -> None: +def test_deserialize_attr_schema_invalid_base64_raises() -> None: with pytest.raises(binascii.Error): - deserialize_polars_dtype("not-base64") + deserialize_attr_schema("not-base64", key="dummy") -def test_deserialize_polars_dtype_non_ipc_payload_raises() -> None: +def test_deserialize_attr_schema_non_ipc_payload_raises() -> None: encoded = base64.b64encode(b"not-arrow-ipc").decode("utf-8") with pytest.raises((OSError, pl.exceptions.PolarsError)): - deserialize_polars_dtype(encoded) + deserialize_attr_schema(encoded, key="dummy") @pytest.mark.parametrize( @@ -68,3 +68,16 @@ def test_serialize_deserialize_attr_schema_roundtrip(schema: AttrSchema) -> None encoded = serialize_attr_schema(schema) restored = deserialize_attr_schema(encoded, key=schema.key) assert restored == schema + + +def test_serialize_attr_schema_stores_default_in_dummy_row() -> None: + schema = AttrSchema(key="score", 
dtype=pl.Float64, default_value=1.25) + encoded = serialize_attr_schema(schema) + + payload = base64.b64decode(encoded) + df = pl.read_ipc(io.BytesIO(payload)) + + assert "__attr_schema_value__" in df.columns + assert df.schema["__attr_schema_value__"] == pl.Float64 + assert df["__attr_schema_value__"][0] == 1.25 + assert "__attr_schema_dtype_pickle__" not in df.columns From 7e61ac33ee6ec0976c01eed1e084fa86c2ee7bcc Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 11:11:10 +0900 Subject: [PATCH 09/30] solved failing tests --- src/tracksdata/solvers/_ilp_solver.py | 9 ++++++++- src/tracksdata/solvers/_nearest_neighbors_solver.py | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/tracksdata/solvers/_ilp_solver.py b/src/tracksdata/solvers/_ilp_solver.py index 6485eaf7..3f6676d5 100644 --- a/src/tracksdata/solvers/_ilp_solver.py +++ b/src/tracksdata/solvers/_ilp_solver.py @@ -175,6 +175,9 @@ def _evaluate_expr( expr: Attr, df: pl.DataFrame, ) -> list[float]: + if df.is_empty(): + return [] + if len(expr.expr_columns) == 0: return [expr.evaluate(df).item()] * len(df) else: @@ -388,7 +391,11 @@ def solve( node_attr_keys.extend(self.merge_weight_expr.columns) nodes_df = graph.node_attrs(attr_keys=node_attr_keys) - edges_df = graph.edge_attrs(attr_keys=self.edge_weight_expr.columns) + # When no edges exist, avoid requesting edge weight columns that may not + # be registered in the backend schema yet. _solve() handles this as a + # regular "no edges" ValueError. 
+ edge_attr_keys = [] if graph.num_edges() == 0 else self.edge_weight_expr.columns + edges_df = graph.edge_attrs(attr_keys=edge_attr_keys) self._add_objective_and_variables(nodes_df, edges_df) self._add_continuous_flow_constraints(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_list(), edges_df) diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 34011dee..21915290 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -235,7 +235,8 @@ def solve( The graph view of the solution if `return_solution` is True, otherwise None. """ # get edges and sort them by weight - edges_df = graph.edge_attrs(attr_keys=self.edge_weight_expr.columns) + edge_attr_keys = [] if graph.num_edges() == 0 else self.edge_weight_expr.columns + edges_df = graph.edge_attrs(attr_keys=edge_attr_keys) if len(edges_df) == 0: raise ValueError("No edges found in the graph, there is nothing to solve.") From e5968bf44aa9697756df354b97f005f05c2a2bd6 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 18 Feb 2026 11:45:10 +0900 Subject: [PATCH 10/30] added test for shape-less pl.Array (xfail) --- src/tracksdata/graph/_sql_graph.py | 8 +------- .../graph/_test/test_graph_backends.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index b9da5c1a..2d0ea3ad 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -590,13 +590,7 @@ def _get_attr_schemas_from_store(self, *, is_node: bool) -> dict[str, AttrSchema schemas: dict[str, AttrSchema] = {} for key, encoded_schema in section.items(): - try: - schemas[key] = deserialize_attr_schema(encoded_schema, key=key) - except Exception: - LOG.warning( - "Failed to deserialize SQL schema metadata for key '%s'. 
Falling back to table inference.", - key, - ) + schemas[key] = deserialize_attr_schema(encoded_schema, key=key) return schemas diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 63013a4d..f60d9641 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1523,6 +1523,23 @@ def test_from_other_preserves_schema_roundtrip(target_cls: type[BaseGraph], targ assert graph3.edge_attrs().schema == graph.edge_attrs().schema +@pytest.mark.xfail(reason="This is because of the lack of support of shape-less pl.Array in write_ipc of polars.") +def test_from_other_with_array_no_shape(): + """Test that from_other raises an error when trying to copy array attributes without shape information.""" + graph = RustWorkXGraph() + graph.add_node_attr_key("array_attr", pl.Array) + graph.add_node({"t": 0, "array_attr": np.array([1.0, 2.0, 3.0], dtype=np.float32)}) + + # This should raise an error because the schema does not include shape information + graph2 = SQLGraph.from_other( + graph, drivername="sqlite", database=":memory:", engine_kwargs={"connect_args": {"check_same_thread": False}} + ) + assert graph2.num_nodes() == graph.num_nodes() + assert set(graph2.node_attr_keys()) == set(graph.node_attr_keys()) + assert graph2._node_attr_schemas() == graph._node_attr_schemas() + assert graph2.node_attrs().schema == graph.node_attrs().schema + + @pytest.mark.parametrize( ("target_cls", "target_kwargs"), [ From b4acde3e8985b7eacd5a24328f3d4840dc49889c Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 10:43:59 +0900 Subject: [PATCH 11/30] working --- src/tracksdata/graph/_sql_graph.py | 65 ++++--------------- .../graph/_test/test_graph_backends.py | 6 +- 2 files changed, 16 insertions(+), 55 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 2d0ea3ad..20ea59a9 100644 --- 
a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -443,7 +443,8 @@ class SQLGraph(BaseGraph): """ node_id_time_multiplier: int = 1_000_000_000 - _PRIVATE_SQL_SCHEMA_STORE_KEY = "__private_sql_attr_schema_store" + _PRIVATE_SQL_NODE_SCHEMA_STORE_KEY = "__private_sql_node_attr_schema_store" + _PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY = "__private_sql_edge_attr_schema_store" Base: type[DeclarativeBase] Node: type[DeclarativeBase] Edge: type[DeclarativeBase] @@ -552,69 +553,26 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata - @classmethod - def _empty_attr_schema_store(cls) -> dict[str, dict[str, str]]: - return {"node": {}, "edge": {}} - - def _attr_schema_store(self) -> dict[str, dict[str, str]]: - store = self._private_metadata.get(self._PRIVATE_SQL_SCHEMA_STORE_KEY, {}) - if not isinstance(store, dict): - return self._empty_attr_schema_store() - - normalized = self._empty_attr_schema_store() - for section_key in ("node", "edge"): - section = store.get(section_key, {}) - if not isinstance(section, dict): - continue - for key, encoded_schema in section.items(): - if isinstance(encoded_schema, str): - normalized[section_key][key] = encoded_schema - - return normalized - - def _set_attr_schema_store(self, store: dict[str, dict[str, str]]) -> None: - normalized = self._empty_attr_schema_store() - for section_key in ("node", "edge"): - section = store.get(section_key, {}) - if not isinstance(section, dict): - continue - for key, encoded_schema in section.items(): - if isinstance(encoded_schema, str): - normalized[section_key][key] = encoded_schema - - self._private_metadata.update(**{self._PRIVATE_SQL_SCHEMA_STORE_KEY: normalized}) - - def _get_attr_schemas_from_store(self, *, is_node: bool) -> dict[str, AttrSchema]: - section_key = "node" if is_node else "edge" - section = self._attr_schema_store()[section_key] - - schemas: dict[str, AttrSchema] = {} - for key, encoded_schema in section.items(): - schemas[key] = 
deserialize_attr_schema(encoded_schema, key=key) - - return schemas - - def _set_attr_schemas_to_store(self, *, is_node: bool, schemas: dict[str, AttrSchema]) -> None: - section_key = "node" if is_node else "edge" - store = self._attr_schema_store() - store[section_key] = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} - self._set_attr_schema_store(store) @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: - return self._get_attr_schemas_from_store(is_node=True) + encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, {}) + return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} @__node_attr_schemas.setter def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: - self._set_attr_schemas_to_store(is_node=True, schemas=schemas) + encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._private_metadata[self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY] = encoded_schemas @property def __edge_attr_schemas(self) -> dict[str, AttrSchema]: - return self._get_attr_schemas_from_store(is_node=False) + encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, {}) + return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} @__edge_attr_schemas.setter def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: - self._set_attr_schemas_to_store(is_node=False, schemas=schemas) + encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} + self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas def _init_schemas_from_tables(self) -> None: """ @@ -2094,7 +2052,8 @@ def _metadata(self) -> dict[str, Any]: def _private_metadata_for_copy(self) -> dict[str, Any]: private_metadata = super()._private_metadata_for_copy() - 
private_metadata.pop(self._PRIVATE_SQL_SCHEMA_STORE_KEY, None) + private_metadata.pop(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, None) + private_metadata.pop(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, None) return private_metadata def _update_metadata(self, **kwargs) -> None: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index f60d9641..8aa2324f 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1771,10 +1771,12 @@ def test_sql_schema_metadata_not_copied_to_in_memory_graphs() -> None: ) sql_graph.add_edge(n1, n2, {"edge_score_default": 0.75}) - assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY in sql_graph._private_metadata + assert SQLGraph._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY in sql_graph._private_metadata + assert SQLGraph._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY in sql_graph._private_metadata rx_graph = RustWorkXGraph.from_other(sql_graph) - assert SQLGraph._PRIVATE_SQL_SCHEMA_STORE_KEY not in rx_graph._metadata() + assert SQLGraph._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY not in rx_graph._metadata() + assert SQLGraph._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY not in rx_graph._metadata() sql_graph_roundtrip = SQLGraph.from_other( rx_graph, From cc55976e997cf7c48ffcd750b8bf57220b0ba8c0 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 19 Feb 2026 10:48:06 +0900 Subject: [PATCH 12/30] simplified code --- src/tracksdata/graph/_sql_graph.py | 133 +++++++++++++++++------------ 1 file changed, 79 insertions(+), 54 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 20ea59a9..65784572 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -479,9 +479,6 @@ def __init__( self.Base.metadata.create_all(self._engine) - # Initialize schemas from existing table columns - self._init_schemas_from_tables() - self._max_id_per_time = {} self._update_max_id_per_time() @@ -553,72 
+550,102 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata + @staticmethod + def _default_node_attr_schemas() -> dict[str, AttrSchema]: + return { + DEFAULT_ATTR_KEYS.T: AttrSchema(key=DEFAULT_ATTR_KEYS.T, dtype=pl.Int32), + DEFAULT_ATTR_KEYS.NODE_ID: AttrSchema(key=DEFAULT_ATTR_KEYS.NODE_ID, dtype=pl.Int64), + } + + @staticmethod + def _default_edge_attr_schemas() -> dict[str, AttrSchema]: + return { + DEFAULT_ATTR_KEYS.EDGE_ID: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_ID, dtype=pl.Int32), + DEFAULT_ATTR_KEYS.EDGE_SOURCE: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_SOURCE, dtype=pl.Int64), + DEFAULT_ATTR_KEYS.EDGE_TARGET: AttrSchema(key=DEFAULT_ATTR_KEYS.EDGE_TARGET, dtype=pl.Int64), + } + + def _attr_schemas_from_metadata( + self, + *, + table_class: type[DeclarativeBase], + metadata_key: str, + default_schemas: dict[str, AttrSchema], + preferred_order: Sequence[str], + ) -> dict[str, AttrSchema]: + encoded_schemas = self._private_metadata.get(metadata_key, {}) + schemas = default_schemas.copy() + schemas.update( + {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + ) + + # Legacy databases may not have schema metadata for all columns. 
+ for column_name, column in table_class.__table__.columns.items(): + if column_name not in schemas: + schemas[column_name] = AttrSchema( + key=column_name, + dtype=sqlalchemy_type_to_polars_dtype(column.type), + ) + + ordered_keys = [key for key in preferred_order if key in schemas] + ordered_keys.extend(key for key in table_class.__table__.columns.keys() if key not in ordered_keys) + ordered_keys.extend(key for key in schemas if key not in ordered_keys) + return {key: schemas[key] for key in ordered_keys} + + def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[str, AttrSchema]: + if table_class.__tablename__ == self.Node.__tablename__: + return self._node_attr_schemas() + return self._edge_attr_schemas() + + @staticmethod + def _is_pickled_sql_type(column_type: TypeEngine) -> bool: + return isinstance(column_type, sa.PickleType | sa.LargeBinary) @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: - encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, {}) - return {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + return self._attr_schemas_from_metadata( + table_class=self.Node, + metadata_key=self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY, + default_schemas=self._default_node_attr_schemas(), + preferred_order=[DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID], + ) @__node_attr_schemas.setter def __node_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + merged_schemas = self._default_node_attr_schemas() + merged_schemas.update(schemas) + schemas = merged_schemas encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} self._private_metadata[self._PRIVATE_SQL_NODE_SCHEMA_STORE_KEY] = encoded_schemas @property def __edge_attr_schemas(self) -> dict[str, AttrSchema]: - encoded_schemas = self._private_metadata.get(self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, {}) - return {key: 
deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} + return self._attr_schemas_from_metadata( + table_class=self.Edge, + metadata_key=self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY, + default_schemas=self._default_edge_attr_schemas(), + preferred_order=[ + DEFAULT_ATTR_KEYS.EDGE_ID, + DEFAULT_ATTR_KEYS.EDGE_SOURCE, + DEFAULT_ATTR_KEYS.EDGE_TARGET, + ], + ) @__edge_attr_schemas.setter def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: + merged_schemas = self._default_edge_attr_schemas() + merged_schemas.update(schemas) + schemas = merged_schemas encoded_schemas = {key: serialize_attr_schema(schema) for key, schema in schemas.items()} self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas - def _init_schemas_from_tables(self) -> None: - """ - Initialize AttrSchema objects from existing database table columns. - This is used when loading an existing graph from the database. - """ - node_column_names = list(self.Node.__table__.columns.keys()) - preferred_node_order = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID] - ordered_node_columns = [name for name in preferred_node_order if name in node_column_names] - ordered_node_columns.extend(name for name in node_column_names if name not in preferred_node_order) - - node_schemas = {k: v for k, v in self.__node_attr_schemas.items() if k in ordered_node_columns} - for column_name in ordered_node_columns: - if column_name in node_schemas: - continue - column = self.Node.__table__.columns[column_name] - node_schemas[column_name] = AttrSchema( - key=column_name, - dtype=sqlalchemy_type_to_polars_dtype(column.type), - ) - self.__node_attr_schemas = node_schemas - - # Initialize edge schemas from Edge table columns - edge_column_names = list(self.Edge.__table__.columns.keys()) - edge_schemas = {k: v for k, v in self.__edge_attr_schemas.items() if k in edge_column_names} - for column_name in self.Edge.__table__.columns.keys(): - if column_name 
in edge_schemas: - continue - column = self.Edge.__table__.columns[column_name] - edge_schemas[column_name] = AttrSchema( - key=column_name, - dtype=sqlalchemy_type_to_polars_dtype(column.type), - ) - self.__edge_attr_schemas = edge_schemas - def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: if isinstance(column.type, sa.LargeBinary): column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: - # Get the appropriate schema dict based on table class - if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas() - else: - schemas = self._edge_attr_schemas() + schemas = self._attr_schemas_for_table(table_class) # Return schema overrides for columns safely represented in SQL. # Pickled columns are unpickled and casted in a second pass. @@ -627,21 +654,19 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD for key, schema in schemas.items() if ( key in table_class.__table__.columns - and not isinstance(table_class.__table__.columns[key].type, sa.PickleType | sa.LargeBinary) - and not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) ) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: - # Get the appropriate schema dict based on table class - if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas() - else: - schemas = self._edge_attr_schemas() + schemas = self._attr_schemas_for_table(table_class) casts: list[pl.Series] = [] for key, schema in schemas.items(): - if key not in df.columns: + if key not in df.columns or key not in table_class.__table__.columns: + continue + + if not self._is_pickled_sql_type(table_class.__table__.columns[key].type): continue try: From e76d8e5f82e1d3899f7abab94b17a91a6037aa82 Mon Sep 17 
00:00:00 2001 From: Yohsuke Fukai Date: Fri, 20 Feb 2026 08:43:16 +0900 Subject: [PATCH 13/30] initial try --- src/tracksdata/_test/test_attrs.py | 17 ++ src/tracksdata/attrs.py | 69 ++++++- src/tracksdata/graph/_rustworkx_graph.py | 21 ++ src/tracksdata/graph/_sql_graph.py | 185 +++++++++++++++++- .../graph/_test/test_graph_backends.py | 32 +++ src/tracksdata/utils/_dtypes.py | 5 + 6 files changed, 315 insertions(+), 14 deletions(-) diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index 8b29bf7d..481736fc 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -115,6 +115,23 @@ def test_attr_expr_method_delegation() -> None: assert result.to_list() == expected.to_list() +def test_attr_expr_struct_field_method_delegation() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 3}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + expr = NodeAttr("s").struct.field("x") + result = expr.evaluate(df) + assert isinstance(expr, NodeAttr) + assert result.to_list() == [1, 2, 3] + + +def test_attr_comparison_struct_field() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 1}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + comp = NodeAttr("s").struct.field("x") == 1 + result = comp.to_attr().evaluate(df) + assert comp.column == "s" + assert comp.field_path == ("x",) + assert result.to_list() == [True, False, True] + + def test_attr_expr_complex_expression() -> None: df = pl.DataFrame({"iou": [0.5, 0.7, 0.9], "distance": [10, 20, 30]}) expr = (1 - Attr("iou")) * Attr("distance") diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 60f82db8..f5132d45 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -129,7 +129,8 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr raise ValueError(f"Comparison operators are not supported for multiple columns. 
Found {columns}.") self.attr = attr - self.column = columns[0] + self.column = attr.root_column if attr.root_column is not None else columns[0] + self.field_path = attr.field_path self.op = op # casting numpy scalars to python scalars @@ -144,14 +145,18 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr self.other = other def __repr__(self) -> str: - return f"{type(self.attr).__name__}({self.column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" + if self.field_path: + column = ".".join([str(self.column), *self.field_path]) + else: + column = str(self.column) + return f"{type(self.attr).__name__}({column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" def to_attr(self) -> "Attr": """ Transform the comparison back to an [Attr][tracksdata.attrs.Attr] object. This is useful for evaluating the expression on a DataFrame. """ - return Attr(self.op(pl.col(self.column), self.other)) + return Attr(self.op(self.attr.expr, self.other)) def __getattr__(self, attr: str) -> Any: return getattr(self.to_attr(), attr) @@ -198,6 +203,31 @@ def __ge__(self, other: ExprInput) -> "Attr": ... def __rge__(self, other: ExprInput) -> "Attr": ... +class _StructNamespace: + """Wrapper around polars struct namespace that preserves Attr semantics.""" + + def __init__(self, attr: "Attr") -> None: + self._attr = attr + self._namespace = attr.expr.struct + + def field(self, name: str) -> "Attr": + out = self._attr._wrap(self._namespace.field(name), preserve_field_path=True) + if isinstance(out, Attr): + out._append_field_path(name) + return out + + def __getattr__(self, name: str) -> Any: + namespace_attr = getattr(self._namespace, name) + if callable(namespace_attr): + + @functools.wraps(namespace_attr) + def _wrapped(*args, **kwargs): + return self._attr._wrap(namespace_attr(*args, **kwargs)) + + return _wrapped + return namespace_attr + + class Attr: """ A class to compose an attribute expression for attribute filtering or value evaluation. 
@@ -222,30 +252,40 @@ class Attr: def __init__(self, value: ExprInput) -> None: self._inf_exprs = [] # expressions multiplied by +inf self._neg_inf_exprs = [] # expressions multiplied by -inf + self._root_column: str | None = None + self._field_path: tuple[str, ...] = () if isinstance(value, str): self.expr = pl.col(value) + self._root_column = value elif isinstance(value, Attr): self.expr = value.expr # Copy infinity tracking from the other AttrExpr self._inf_exprs = value.inf_exprs self._neg_inf_exprs = value.neg_inf_exprs + self._root_column = value.root_column + self._field_path = value.field_path elif isinstance(value, AttrComparison): attr = value.to_attr() self.expr = attr.expr self._inf_exprs = attr.inf_exprs self._neg_inf_exprs = attr.neg_inf_exprs + self._root_column = attr.root_column + self._field_path = attr.field_path elif isinstance(value, Expr): self.expr = value else: self.expr = pl.lit(value) - def _wrap(self, expr: ExprInput) -> Union["Attr", Any]: + def _wrap(self, expr: ExprInput, *, preserve_field_path: bool = False) -> Union["Attr", Any]: if isinstance(expr, Expr): - result = Attr(expr) + result = type(self)(expr) # Propagate infinity tracking result._inf_exprs = self._inf_exprs.copy() result._neg_inf_exprs = self._neg_inf_exprs.copy() + if preserve_field_path: + result._root_column = self._root_column + result._field_path = self._field_path return result return expr @@ -377,6 +417,14 @@ def evaluate(self, df: DataFrame) -> Series: def columns(self) -> list[str]: return list(dict.fromkeys(self.expr_columns + self.inf_columns + self.neg_inf_columns)) + @property + def root_column(self) -> str | None: + return self._root_column + + @property + def field_path(self) -> tuple[str, ...]: + return self._field_path + @property def inf_exprs(self) -> list["Attr"]: """Get the expressions multiplied by positive infinity.""" @@ -464,6 +512,9 @@ def __getattr__(self, attr: str) -> Any: if attr.startswith("_"): raise 
AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + if attr == "struct": + return _StructNamespace(self) + # To auto generate operator methods such as `.log()`` expr_attr = getattr(self.expr, attr) if callable(expr_attr): @@ -475,6 +526,12 @@ def _wrapped(*args, **kwargs): return _wrapped return expr_attr + def _append_field_path(self, field_name: str) -> None: + if self._root_column is None: + self._field_path = () + else: + self._field_path = (*self._field_path, field_name) + def __repr__(self) -> str: return f"Attr({self.expr})" @@ -733,4 +790,4 @@ def polars_reduce_attr_comps( # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") - return pl.reduce(reduce_op, [attr_comp.op(df[str(attr_comp.column)], attr_comp.other) for attr_comp in attr_comps]) + return pl.reduce(reduce_op, [attr_comp.op(attr_comp.attr.expr, attr_comp.other) for attr_comp in attr_comps]) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 05cf1c17..76c4cced 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -74,9 +74,30 @@ def _create_filter_func( ) -> Callable[[dict[str, Any]], bool]: LOG.info(f"Creating filter function for {attr_comps}") + def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any: + for field in field_path: + if value is None: + return None + + if isinstance(value, dict): + value = value.get(field, None) + continue + + try: + value = value[field] + except (KeyError, IndexError, TypeError): + try: + value = getattr(value, field) + except AttributeError: + return None + + return value + def _filter(attrs: dict[str, Any]) -> bool: for attr_op in attr_comps: value = attrs.get(attr_op.column, schema[attr_op.column].default_value) + if attr_op.field_path: + value = _extract_field_path(value, attr_op.field_path) if not attr_op.op(value, attr_op.other): return False return 
True diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index c8ea38ed..30eefe4c 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -50,6 +50,54 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _field_scalar_sample(value: Any) -> Any: + if isinstance(value, list): + for item in value: + if item is not None: + if np.isscalar(item) and hasattr(item, "item"): + return item.item() + return item + return None + if np.isscalar(value) and hasattr(value, "item"): + return value.item() + return value + + +def _coerce_json_field_expr(lhs: Any, sample: Any) -> Any: + if isinstance(sample, bool): + if hasattr(lhs, "as_boolean"): + return lhs.as_boolean() + return sa.cast(lhs, sa.Boolean) + if isinstance(sample, int): + if hasattr(lhs, "as_integer"): + return lhs.as_integer() + return sa.cast(lhs, sa.BigInteger) + if isinstance(sample, float): + if hasattr(lhs, "as_float"): + return lhs.as_float() + return sa.cast(lhs, sa.Float) + if isinstance(sample, str): + if hasattr(lhs, "as_string"): + return lhs.as_string() + return sa.cast(lhs, sa.String) + return lhs + + +def _resolve_attr_filter_column( + table: type[DeclarativeBase], + attr_filter: AttrComparison, +) -> Any: + lhs = getattr(table, str(attr_filter.column)) + + if not attr_filter.field_path: + return lhs + + for field in attr_filter.field_path: + lhs = lhs[field] + + return _coerce_json_field_expr(lhs, _field_scalar_sample(attr_filter.other)) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -74,7 +122,10 @@ def _filter_query( """ LOG.info("Filter query:\n%s", attr_filters) query = query.filter( - *[attr_filter.op(getattr(table, str(attr_filter.column)), attr_filter.other) for attr_filter in attr_filters] + *[ + attr_filter.op(_resolve_attr_filter_column(table, attr_filter), attr_filter.other) + for attr_filter in attr_filters + ] ) return query @@ -369,6 +420,59 @@ def 
blob_default(engine: sa.Engine, value: bytes) -> sa.Text: raise RuntimeError(f"Unsupported dialect {engine.dialect.name}") +def _dtype_to_metadata_dict(dtype: pl.DataType) -> dict[str, Any]: + if isinstance(dtype, pl.Struct): + return { + "kind": "struct", + "fields": [ + {"name": field_name, "dtype": _dtype_to_metadata_dict(field_dtype)} + for field_name, field_dtype in dtype.to_schema().items() + ], + } + + if isinstance(dtype, pl.List): + return {"kind": "list", "inner": _dtype_to_metadata_dict(dtype.inner)} + + if isinstance(dtype, pl.Array): + return {"kind": "array", "inner": _dtype_to_metadata_dict(dtype.inner), "shape": list(dtype.shape)} + + return {"kind": "scalar", "name": dtype.base_type().__name__} + + +def _dtype_from_metadata_dict(serialized: Any) -> pl.DataType: + if not isinstance(serialized, dict): + return pl.Object + + kind = serialized.get("kind") + + if kind == "struct": + fields: dict[str, pl.DataType] = {} + for field in serialized.get("fields", []): + if not isinstance(field, dict) or "name" not in field: + continue + fields[str(field["name"])] = _dtype_from_metadata_dict(field.get("dtype")) + return pl.Struct(fields) + + if kind == "list": + return pl.List(_dtype_from_metadata_dict(serialized.get("inner"))) + + if kind == "array": + inner = _dtype_from_metadata_dict(serialized.get("inner")) + shape_raw = serialized.get("shape", []) + if not isinstance(shape_raw, list | tuple) or len(shape_raw) == 0: + return pl.List(inner) + return pl.Array(inner, shape=tuple(shape_raw)) + + if kind == "scalar": + dtype_name = serialized.get("name") + dtype = getattr(pl, dtype_name, None) + if dtype is None: + return pl.Object + return dtype + + return pl.Object + + class SQLGraph(BaseGraph): """ SQL-based graph implementation using SQLAlchemy ORM. 
@@ -441,6 +545,7 @@ class SQLGraph(BaseGraph): """ node_id_time_multiplier: int = 1_000_000_000 + _STRUCT_DTYPE_METADATA_KEY: str = f"{BaseGraph._PRIVATE_METADATA_PREFIX}struct_attr_dtypes" Base: type[DeclarativeBase] Node: type[DeclarativeBase] Edge: type[DeclarativeBase] @@ -479,6 +584,7 @@ def __init__( # Initialize schemas from existing table columns self._init_schemas_from_tables() + self._restore_struct_attr_dtypes_from_private_metadata() self._max_id_per_time = {} self._update_max_id_per_time() @@ -581,6 +687,45 @@ def _init_schemas_from_tables(self) -> None: dtype=pl_dtype, ) + def _struct_attr_dtypes_metadata(self) -> dict[str, dict[str, Any]]: + metadata = self._private_metadata.get(self._STRUCT_DTYPE_METADATA_KEY, None) + if not isinstance(metadata, dict): + return {"node": {}, "edge": {}} + + node_dtypes = metadata.get("node") + edge_dtypes = metadata.get("edge") + if not isinstance(node_dtypes, dict): + node_dtypes = {} + if not isinstance(edge_dtypes, dict): + edge_dtypes = {} + return {"node": node_dtypes, "edge": edge_dtypes} + + def _set_struct_attr_dtypes_metadata(self, metadata: dict[str, dict[str, Any]]) -> None: + self._private_metadata.update(**{self._STRUCT_DTYPE_METADATA_KEY: metadata}) + + def _register_struct_attr_dtype(self, *, table: str, key: str, dtype: pl.DataType) -> None: + metadata = self._struct_attr_dtypes_metadata() + metadata.setdefault(table, {})[key] = _dtype_to_metadata_dict(dtype) + self._set_struct_attr_dtypes_metadata(metadata) + + def _remove_struct_attr_dtype(self, *, table: str, key: str) -> None: + metadata = self._struct_attr_dtypes_metadata() + table_mapping = metadata.get(table, {}) + if key in table_mapping: + table_mapping.pop(key, None) + self._set_struct_attr_dtypes_metadata(metadata) + + def _restore_struct_attr_dtypes_from_private_metadata(self) -> None: + metadata = self._struct_attr_dtypes_metadata() + + for key, serialized in metadata["node"].items(): + if key in self.__node_attr_schemas: + 
self.__node_attr_schemas[key].dtype = _dtype_from_metadata_dict(serialized) + + for key, serialized in metadata["edge"].items(): + if key in self.__edge_attr_schemas: + self.__edge_attr_schemas[key].dtype = _dtype_from_metadata_dict(serialized) + def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: if isinstance(column.type, sa.LargeBinary): @@ -597,7 +742,7 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD return { key: schema.dtype for key, schema in schemas.items() - if not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + if not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List | pl.Struct)) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: @@ -607,12 +752,28 @@ def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFra else: schemas = self._edge_attr_schemas() - # Cast array columns (stored as blobs in database) - df = df.with_columns( - pl.Series(key, df[key].to_list(), dtype=schema.dtype) - for key, schema in schemas.items() - if isinstance(schema.dtype, pl.Array) and key in df.columns - ) + for key, schema in schemas.items(): + if key not in df.columns: + continue + + if isinstance(schema.dtype, pl.Array): + # Array columns are stored as pickled blobs. + df = df.with_columns(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + continue + + if isinstance(schema.dtype, pl.Struct): + source_dtype = df.schema[key] + if source_dtype == pl.String: + # SQLite returns JSON columns as strings. 
+ df = df.with_columns( + pl.when(pl.col(key).is_null()) + .then(None) + .otherwise(pl.col(key).str.json_decode(schema.dtype)) + .alias(key) + ) + elif source_dtype != schema.dtype: + df = df.with_columns(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + return df def _update_max_id_per_time(self) -> None: @@ -1626,6 +1787,9 @@ def add_node_attr_key( # Add column to database self._add_new_column(self.Node, schema) + if isinstance(schema.dtype, pl.Struct): + self._register_struct_attr_dtype(table="node", key=schema.key, dtype=schema.dtype) + def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): raise ValueError(f"Node attribute key {key} does not exist") @@ -1635,6 +1799,7 @@ def remove_node_attr_key(self, key: str) -> None: self._drop_column(self.Node, key) self.__node_attr_schemas.pop(key, None) + self._remove_struct_attr_dtype(table="node", key=key) def add_edge_attr_key( self, @@ -1651,12 +1816,16 @@ def add_edge_attr_key( # Add column to database self._add_new_column(self.Edge, schema) + if isinstance(schema.dtype, pl.Struct): + self._register_struct_attr_dtype(table="edge", key=schema.key, dtype=schema.dtype) + def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") self._drop_column(self.Edge, key) self.__edge_attr_schemas.pop(key, None) + self._remove_struct_attr_dtype(table="edge", key=key) def num_edges(self) -> int: with Session(self._engine) as session: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index e9088c75..49290c97 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -224,6 +224,20 @@ def test_filter_nodes_by_membership(graph_backend: BaseGraph) -> None: assert set(np_members) == {node_b} +def test_filter_nodes_by_struct_field(graph_backend: BaseGraph) -> None: + 
graph_backend.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "name": pl.String})) + + node_a = graph_backend.add_node({"t": 0, "measurements": {"score": 1, "name": "A"}}) + node_b = graph_backend.add_node({"t": 1, "measurements": {"score": 2, "name": "B"}}) + node_c = graph_backend.add_node({"t": 2, "measurements": {"score": 1, "name": "C"}}) + + score_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("score") == 1).node_ids() + assert set(score_nodes) == {node_a, node_c} + + name_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("name") == "B").node_ids() + assert set(name_nodes) == {node_b} + + def test_time_points(graph_backend: BaseGraph) -> None: """Test retrieving time points.""" graph_backend.add_node({"t": 0}) @@ -1603,6 +1617,24 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: np.testing.assert_array_equal(stored_mask.mask, mask_data) +def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None: + db_path = tmp_path / "struct_graph.db" + graph = SQLGraph("sqlite", str(db_path)) + graph.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "label": pl.String})) + + node_id = graph.add_node({"t": 0, "measurements": {"score": 7, "label": "A"}}) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + df = reloaded.node_attrs(attr_keys=["measurements"]) + assert df.schema["measurements"] == pl.Struct({"score": pl.Int64, "label": pl.String}) + assert df["measurements"].to_list() == [{"score": 7, "label": "A"}] + + ids = reloaded.filter(NodeAttr("measurements").struct.field("score") == 7).node_ids() + assert ids == [node_id] + + def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: """Reloading a SQLGraph should respect existing max IDs per time point.""" db_path = tmp_path / "id_restore.db" diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 8e671487..85eacff0 100644 --- 
a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -383,6 +383,10 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: >>> polars_dtype_to_sqlalchemy_type(pl.Boolean) """ + # Handle struct types as JSON for backend-level field filtering. + if isinstance(dtype, pl.Struct): + return sa.JSON() + # Handle sequence types - use PickleType for storage if isinstance(dtype, pl.Array | pl.List): return sa.PickleType() @@ -407,6 +411,7 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: (sa.Float, pl.Float64), (sa.Text, pl.String), # Must come before String (sa.String, pl.String), + (sa.JSON, pl.Object), (sa.PickleType, pl.Object), # Must come before LargeBinary (sa.LargeBinary, pl.Object), ] From 7bec3697ee826f1b7790750766f4d22e47676b35 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Fri, 20 Feb 2026 13:50:17 +0900 Subject: [PATCH 14/30] saving private metadata --- src/tracksdata/graph/_base_graph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 03dc3a01..fbae168f 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1852,6 +1852,7 @@ def to_geff( } td_metadata = self.metadata.copy() + td_metadata.update(self._private_metadata) td_metadata.pop("geff", None) # avoid geff being written multiple times geff_metadata = geff.GeffMetadata( From 852f717989208652f5d4b0766b6e0fa73f47b9bd Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 26 Feb 2026 11:23:29 +0900 Subject: [PATCH 15/30] rustworkx reviewed --- .codex/environments/environment.toml | 10 +++++++++ src/tracksdata/_test/test_attrs.py | 2 +- src/tracksdata/attrs.py | 27 +++++++++++++++++++++--- src/tracksdata/graph/_rustworkx_graph.py | 4 ++-- src/tracksdata/graph/_sql_graph.py | 4 ++-- 5 files changed, 39 insertions(+), 8 deletions(-) create mode 100644 .codex/environments/environment.toml diff --git 
a/.codex/environments/environment.toml b/.codex/environments/environment.toml new file mode 100644 index 00000000..1324ca94 --- /dev/null +++ b/.codex/environments/environment.toml @@ -0,0 +1,10 @@ +# THIS IS AUTOGENERATED. DO NOT EDIT MANUALLY +version = 1 +name = "tracksdata" + +[setup] +script = ''' +uv venv +uv pip install -e .[spatial,test,docs] +source .venv/bin/activate +''' diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index 481736fc..cc403723 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -128,7 +128,7 @@ def test_attr_comparison_struct_field() -> None: comp = NodeAttr("s").struct.field("x") == 1 result = comp.to_attr().evaluate(df) assert comp.column == "s" - assert comp.field_path == ("x",) + assert comp.attr.field_path == ("x",) assert result.to_list() == [True, False, True] diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index f5132d45..be223211 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -130,7 +130,6 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr self.attr = attr self.column = attr.root_column if attr.root_column is not None else columns[0] - self.field_path = attr.field_path self.op = op # casting numpy scalars to python scalars @@ -145,8 +144,8 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr self.other = other def __repr__(self) -> str: - if self.field_path: - column = ".".join([str(self.column), *self.field_path]) + if self.attr.field_path: + column = ".".join([str(self.column), *self.attr.field_path]) else: column = str(self.column) return f"{type(self.attr).__name__}({column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" @@ -252,6 +251,9 @@ class Attr: def __init__(self, value: ExprInput) -> None: self._inf_exprs = [] # expressions multiplied by +inf self._neg_inf_exprs = [] # expressions multiplied by -inf + # Path-tracking for backend filters: + # - 
root_column: top-level column used to store the value. + # - field_path: nested struct path from that root column. self._root_column: str | None = None self._field_path: tuple[str, ...] = () @@ -419,10 +421,29 @@ def columns(self) -> list[str]: @property def root_column(self) -> str | None: + """ + Top-level column name from which this expression originates. + + Examples + -------- + `Attr("t").root_column == "t"` + `NodeAttr("measurements").struct.field("score").root_column == "measurements"` + """ return self._root_column @property def field_path(self) -> tuple[str, ...]: + """ + Nested struct-field path relative to [root_column][tracksdata.attrs.Attr.root_column]. + + Empty tuple means no nested access. + + Examples + -------- + `Attr("t").field_path == ()` + `NodeAttr("measurements").struct.field("score").field_path == ("score",)` + `NodeAttr("meta").struct.field("det").struct.field("conf").field_path == ("det", "conf")` + """ return self._field_path @property diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 76c4cced..d5aaa154 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -96,8 +96,8 @@ def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any: def _filter(attrs: dict[str, Any]) -> bool: for attr_op in attr_comps: value = attrs.get(attr_op.column, schema[attr_op.column].default_value) - if attr_op.field_path: - value = _extract_field_path(value, attr_op.field_path) + if attr_op.attr.field_path: + value = _extract_field_path(value, attr_op.attr.field_path) if not attr_op.op(value, attr_op.other): return False return True diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 30eefe4c..2227e602 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -89,10 +89,10 @@ def _resolve_attr_filter_column( ) -> Any: lhs = getattr(table, str(attr_filter.column)) - if not 
attr_filter.field_path: + if not attr_filter.attr.field_path: return lhs - for field in attr_filter.field_path: + for field in attr_filter.attr.field_path: lhs = lhs[field] return _coerce_json_field_expr(lhs, _field_scalar_sample(attr_filter.other)) From 4af9904f0f32d74047ab1d81ee2ec61904144127 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 26 Feb 2026 11:51:27 +0900 Subject: [PATCH 16/30] working with clean code? --- src/tracksdata/graph/_sql_graph.py | 150 +++++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 6 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 65784572..dbcee51b 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -52,10 +52,100 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _coerce_json_field_expr(lhs: Any, dtype: pl.DataType | None) -> Any: + if dtype is None: + return lhs + + dtype_base = dtype.base_type() + + if dtype_base == pl.Boolean: + if hasattr(lhs, "as_boolean"): + return lhs.as_boolean() + return sa.cast(lhs, sa.Boolean) + if dtype_base in {pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64}: + if hasattr(lhs, "as_integer"): + return lhs.as_integer() + return sa.cast(lhs, sa.BigInteger) + if dtype_base in {pl.Float16, pl.Float32, pl.Float64}: + if hasattr(lhs, "as_float"): + return lhs.as_float() + return sa.cast(lhs, sa.Float) + if dtype_base in {pl.String, pl.Utf8}: + if hasattr(lhs, "as_string"): + return lhs.as_string() + return sa.cast(lhs, sa.String) + return lhs + + +def _field_dtype_from_schema( + attr_filter: AttrComparison, + attr_schemas: dict[str, AttrSchema] | None, +) -> pl.DataType | None: + if attr_schemas is None: + return None + + schema = attr_schemas.get(str(attr_filter.column)) + if schema is None: + return None + + dtype = schema.dtype + for field in attr_filter.attr.field_path: + if not isinstance(dtype, pl.Struct): + return None + + 
dtype = dtype.to_schema().get(field) + if dtype is None: + return None + + return dtype + + +def _resolve_attr_filter_column( + table: type[DeclarativeBase], + attr_filter: AttrComparison, + attr_schemas: dict[str, AttrSchema] | None = None, +) -> Any: + lhs = getattr(table, str(attr_filter.column)) + + if not attr_filter.attr.field_path: + return lhs + + for field in attr_filter.attr.field_path: + lhs = lhs[field] + + field_dtype = _field_dtype_from_schema(attr_filter, attr_schemas) + return _coerce_json_field_expr(lhs, field_dtype) + + +def _json_decode_safe_dtype(dtype: pl.DataType) -> pl.DataType: + """ + Return a JSON-decodable dtype by replacing fixed-size arrays with lists recursively. + """ + if isinstance(dtype, pl.Array): + return pl.List(_json_decode_safe_dtype(dtype.inner)) + + if isinstance(dtype, pl.List): + return pl.List(_json_decode_safe_dtype(dtype.inner)) + + if isinstance(dtype, pl.Struct): + return pl.Struct({key: _json_decode_safe_dtype(inner) for key, inner in dtype.to_schema().items()}) + + return dtype + + +def _struct_json_decode_expr(column: str, target_dtype: pl.Struct) -> pl.Expr: + decode_dtype = _json_decode_safe_dtype(target_dtype) + decoded_expr = pl.when(pl.col(column).is_null()).then(None).otherwise(pl.col(column).str.json_decode(decode_dtype)) + if decode_dtype != target_dtype: + decoded_expr = decoded_expr.cast(target_dtype) + return decoded_expr.alias(column) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], attr_filters: list[AttrComparison], + attr_schemas: dict[str, AttrSchema] | None = None, ) -> sa.Select: """ Filter a query by a list of attribute filters. @@ -68,6 +158,8 @@ def _filter_query( The table to filter. attr_filters : list[AttrComparison] The attribute filters to apply. + attr_schemas : dict[str, AttrSchema] | None, optional + Attribute schema map used to resolve nested struct field dtypes. 
Returns ------- @@ -76,7 +168,13 @@ def _filter_query( """ LOG.info("Filter query:\n%s", attr_filters) query = query.filter( - *[attr_filter.op(getattr(table, str(attr_filter.column)), attr_filter.other) for attr_filter in attr_filters] + *[ + attr_filter.op( + _resolve_attr_filter_column(table, attr_filter, attr_schemas=attr_schemas), + attr_filter.other, + ) + for attr_filter in attr_filters + ] ) return query @@ -100,6 +198,8 @@ def __init__( self._node_query: sa.Select = sa.select(self._graph.Node) self._edge_query: sa.Select = sa.select(self._graph.Edge) node_filtered = False + node_attr_schemas = self._graph._node_attr_schemas() + edge_attr_schemas = self._graph._edge_attr_schemas() if node_ids is not None: if hasattr(node_ids, "tolist"): @@ -119,7 +219,12 @@ def __init__( if self._node_attr_comps: node_filtered = True # filtering nodes by attributes - self._node_query = _filter_query(self._node_query, self._graph.Node, self._node_attr_comps) + self._node_query = _filter_query( + self._node_query, + self._graph.Node, + self._node_attr_comps, + attr_schemas=node_attr_schemas, + ) # if both node and edge attributes are filtered # we need to select subset of edges that belong to the filtered nodes @@ -135,17 +240,32 @@ def __init__( SourceNode, self._graph.Edge.source_id == SourceNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, SourceNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + SourceNode, + self._node_attr_comps, + attr_schemas=node_attr_schemas, + ) if self._include_sources or include_none: self._edge_query = self._edge_query.join( TargetNode, self._graph.Edge.target_id == TargetNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, TargetNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + TargetNode, + self._node_attr_comps, + attr_schemas=node_attr_schemas, + ) if self._edge_attr_comps: - self._edge_query = _filter_query(self._edge_query, 
self._graph.Edge, self._edge_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + self._graph.Edge, + self._edge_attr_comps, + attr_schemas=edge_attr_schemas, + ) # we haven't filtered the nodes by attributes # so we only return the nodes that are in the edges @@ -601,6 +721,10 @@ def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[st def _is_pickled_sql_type(column_type: TypeEngine) -> bool: return isinstance(column_type, sa.PickleType | sa.LargeBinary) + @staticmethod + def _is_json_sql_type(column_type: TypeEngine) -> bool: + return isinstance(column_type, sa.JSON) + @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: return self._attr_schemas_from_metadata( @@ -655,18 +779,30 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD if ( key in table_class.__table__.columns and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) + and not isinstance(schema.dtype, pl.Struct) ) } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: schemas = self._attr_schemas_for_table(table_class) + decode_exprs: list[pl.Expr] = [] casts: list[pl.Series] = [] for key, schema in schemas.items(): if key not in df.columns or key not in table_class.__table__.columns: continue - if not self._is_pickled_sql_type(table_class.__table__.columns[key].type): + column_type = table_class.__table__.columns[key].type + source_dtype = df.schema[key] + + if isinstance(schema.dtype, pl.Struct) and self._is_json_sql_type(column_type): + if source_dtype == pl.String: + decode_exprs.append(_struct_json_decode_expr(key, schema.dtype)) + elif source_dtype != schema.dtype: + casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + continue + + if not self._is_pickled_sql_type(column_type): continue try: @@ -675,6 +811,8 @@ def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFra # Keep original dtype when values 
cannot be casted to the target schema. continue + if decode_exprs: + df = df.with_columns(decode_exprs) if casts: df = df.with_columns(casts) return df From 6c69e76813bc42daed20d2a73b910b9569226f28 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Fri, 10 Apr 2026 05:46:54 +0200 Subject: [PATCH 17/30] updated impl --- src/tracksdata/graph/_sql_graph.py | 340 ++++++++++++++++------------- src/tracksdata/utils/_dtypes.py | 70 +++++- 2 files changed, 256 insertions(+), 154 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 6a287edd..9dd4b570 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -20,8 +20,11 @@ from tracksdata.utils._cache import cache_method from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns from tracksdata.utils._dtypes import ( + STRUCT_FIELD_SEP, AttrSchema, deserialize_attr_schema, + flatten_struct_dtype, + flatten_struct_value, polars_dtype_to_sqlalchemy_type, process_attr_key_args, serialize_attr_schema, @@ -56,100 +59,27 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() -def _coerce_json_field_expr(lhs: Any, dtype: pl.DataType | None) -> Any: - if dtype is None: - return lhs - - dtype_base = dtype.base_type() - - if dtype_base == pl.Boolean: - if hasattr(lhs, "as_boolean"): - return lhs.as_boolean() - return sa.cast(lhs, sa.Boolean) - if dtype_base in {pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64}: - if hasattr(lhs, "as_integer"): - return lhs.as_integer() - return sa.cast(lhs, sa.BigInteger) - if dtype_base in {pl.Float16, pl.Float32, pl.Float64}: - if hasattr(lhs, "as_float"): - return lhs.as_float() - return sa.cast(lhs, sa.Float) - if dtype_base in {pl.String, pl.Utf8}: - if hasattr(lhs, "as_string"): - return lhs.as_string() - return sa.cast(lhs, sa.String) - return lhs - - -def _field_dtype_from_schema( - attr_filter: AttrComparison, - attr_schemas: 
dict[str, AttrSchema] | None, -) -> pl.DataType | None: - if attr_schemas is None: - return None - - schema = attr_schemas.get(str(attr_filter.column)) - if schema is None: - return None - - dtype = schema.dtype - for field in attr_filter.attr.field_path: - if not isinstance(dtype, pl.Struct): - return None - - dtype = dtype.to_schema().get(field) - if dtype is None: - return None - - return dtype - - def _resolve_attr_filter_column( table: type[DeclarativeBase], attr_filter: AttrComparison, - attr_schemas: dict[str, AttrSchema] | None = None, ) -> Any: - lhs = getattr(table, str(attr_filter.column)) - - if not attr_filter.attr.field_path: - return lhs + """Return the SQLAlchemy column expression for an AttrComparison. - for field in attr_filter.attr.field_path: - lhs = lhs[field] - - field_dtype = _field_dtype_from_schema(attr_filter, attr_schemas) - return _coerce_json_field_expr(lhs, field_dtype) - - -def _json_decode_safe_dtype(dtype: pl.DataType) -> pl.DataType: - """ - Return a JSON-decodable dtype by replacing fixed-size arrays with lists recursively. + For struct field paths (e.g. ``NodeAttr("m").struct.field("score")``), the + field path is joined with ``STRUCT_FIELD_SEP`` to form the physical flat + column name (e.g. ``m__score``), which is a native SQL column. 
""" - if isinstance(dtype, pl.Array): - return pl.List(_json_decode_safe_dtype(dtype.inner)) - - if isinstance(dtype, pl.List): - return pl.List(_json_decode_safe_dtype(dtype.inner)) - - if isinstance(dtype, pl.Struct): - return pl.Struct({key: _json_decode_safe_dtype(inner) for key, inner in dtype.to_schema().items()}) - - return dtype - + if not attr_filter.attr.field_path: + return getattr(table, str(attr_filter.column)) -def _struct_json_decode_expr(column: str, target_dtype: pl.Struct) -> pl.Expr: - decode_dtype = _json_decode_safe_dtype(target_dtype) - decoded_expr = pl.when(pl.col(column).is_null()).then(None).otherwise(pl.col(column).str.json_decode(decode_dtype)) - if decode_dtype != target_dtype: - decoded_expr = decoded_expr.cast(target_dtype) - return decoded_expr.alias(column) + flat_col = STRUCT_FIELD_SEP.join([str(attr_filter.column), *attr_filter.attr.field_path]) + return getattr(table, flat_col) def _filter_query( query: sa.Select, table: type[DeclarativeBase], attr_filters: list[AttrComparison], - attr_schemas: dict[str, AttrSchema] | None = None, ) -> sa.Select: """ Filter a query by a list of attribute filters. @@ -162,8 +92,6 @@ def _filter_query( The table to filter. attr_filters : list[AttrComparison] The attribute filters to apply. - attr_schemas : dict[str, AttrSchema] | None, optional - Attribute schema map used to resolve nested struct field dtypes. 
Returns ------- @@ -174,7 +102,7 @@ def _filter_query( query = query.filter( *[ attr_filter.op( - _resolve_attr_filter_column(table, attr_filter, attr_schemas=attr_schemas), + _resolve_attr_filter_column(table, attr_filter), attr_filter.other, ) for attr_filter in attr_filters @@ -202,8 +130,6 @@ def __init__( self._node_query: sa.Select = sa.select(self._graph.Node) self._edge_query: sa.Select = sa.select(self._graph.Edge) node_filtered = False - node_attr_schemas = self._graph._node_attr_schemas() - edge_attr_schemas = self._graph._edge_attr_schemas() if node_ids is not None: if hasattr(node_ids, "tolist"): @@ -227,7 +153,6 @@ def __init__( self._node_query, self._graph.Node, self._node_attr_comps, - attr_schemas=node_attr_schemas, ) # if both node and edge attributes are filtered @@ -248,7 +173,6 @@ def __init__( self._edge_query, SourceNode, self._node_attr_comps, - attr_schemas=node_attr_schemas, ) if self._include_sources or include_none: @@ -260,7 +184,6 @@ def __init__( self._edge_query, TargetNode, self._node_attr_comps, - attr_schemas=node_attr_schemas, ) if self._edge_attr_comps: @@ -268,7 +191,6 @@ def __init__( self._edge_query, self._graph.Edge, self._edge_attr_comps, - attr_schemas=edge_attr_schemas, ) # we haven't filtered the nodes by attributes @@ -703,16 +625,22 @@ def _attr_schemas_from_metadata( {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} ) + # Compute the set of flat physical columns that belong to known struct schemas, + # so the legacy fallback below does not register them as independent logical keys. + known_flat_cols: set[str] = set() + for schema in schemas.values(): + if isinstance(schema.dtype, pl.Struct): + known_flat_cols.update(fc for fc, _ in flatten_struct_dtype(schema.key, schema.dtype)) + # Legacy databases may not have schema metadata for all columns. 
for column_name, column in table_class.__table__.columns.items(): - if column_name not in schemas: + if column_name not in schemas and column_name not in known_flat_cols: schemas[column_name] = AttrSchema( key=column_name, dtype=sqlalchemy_type_to_polars_dtype(column.type), ) ordered_keys = [key for key in preferred_order if key in schemas] - ordered_keys.extend(key for key in table_class.__table__.columns.keys() if key not in ordered_keys) ordered_keys.extend(key for key in schemas if key not in ordered_keys) return {key: schemas[key] for key in ordered_keys} @@ -725,10 +653,6 @@ def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[st def _is_pickled_sql_type(column_type: TypeEngine) -> bool: return isinstance(column_type, sa.PickleType | sa.LargeBinary) - @staticmethod - def _is_json_sql_type(column_type: TypeEngine) -> bool: - return isinstance(column_type, sa.JSON) - @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: return self._attr_schemas_from_metadata( @@ -773,52 +697,92 @@ def _restore_pickled_column_types(self, table: sa.Table) -> None: column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: + """Return polars dtype overrides for physical columns in *table_class*. + + Flat struct leaf columns are included with their native leaf dtypes. + Pickled columns are excluded here and handled in a second pass by + ``_cast_array_columns``. + """ + overrides: SchemaDict = {} schemas = self._attr_schemas_for_table(table_class) + table_cols = table_class.__table__.columns - # Return schema overrides for columns safely represented in SQL. - # Pickled columns are unpickled and casted in a second pass. 
- return { - key: schema.dtype - for key, schema in schemas.items() - if ( - key in table_class.__table__.columns - and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) - and not isinstance(schema.dtype, pl.Struct) - ) - } + for key, schema in schemas.items(): + if isinstance(schema.dtype, pl.Struct): + # Emit overrides for each leaf physical column. + for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype): + if flat_col in table_cols and not self._is_pickled_sql_type(table_cols[flat_col].type): + overrides[flat_col] = leaf_dtype + elif key in table_cols and not self._is_pickled_sql_type(table_cols[key].type): + overrides[key] = schema.dtype + + return overrides + + @staticmethod + def _build_struct_expr(key: str, dtype: pl.Struct) -> pl.Expr: + """Recursively build a ``pl.struct`` expression from flat leaf columns.""" + fields: list[pl.Expr] = [] + for field_name, field_dtype in dtype.to_schema().items(): + flat_col = f"{key}{STRUCT_FIELD_SEP}{field_name}" + if isinstance(field_dtype, pl.Struct): + fields.append(SQLGraph._build_struct_expr(flat_col, field_dtype).alias(field_name)) + else: + fields.append(pl.col(flat_col).alias(field_name)) + return pl.struct(fields) def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: + """Cast pickled columns to their target dtype and reconstruct struct columns.""" schemas = self._attr_schemas_for_table(table_class) + table_cols = table_class.__table__.columns - decode_exprs: list[pl.Expr] = [] casts: list[pl.Series] = [] + struct_keys: list[tuple[str, pl.Struct]] = [] + for key, schema in schemas.items(): - if key not in df.columns or key not in table_class.__table__.columns: + if isinstance(schema.dtype, pl.Struct): + # Cast any pickled flat leaf columns to their proper dtypes before + # reconstruction so Array/List fields have correct dtype. 
+ for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype): + if flat_col not in df.columns or flat_col not in table_cols: + continue + if not self._is_pickled_sql_type(table_cols[flat_col].type): + continue + try: + casts.append(pl.Series(flat_col, df[flat_col].to_list(), dtype=leaf_dtype)) + except Exception: + continue + struct_keys.append((key, schema.dtype)) continue - column_type = table_class.__table__.columns[key].type - source_dtype = df.schema[key] - - if isinstance(schema.dtype, pl.Struct) and self._is_json_sql_type(column_type): - if source_dtype == pl.String: - decode_exprs.append(_struct_json_decode_expr(key, schema.dtype)) - elif source_dtype != schema.dtype: - casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) + if key not in df.columns or key not in table_cols: continue - if not self._is_pickled_sql_type(column_type): + if not self._is_pickled_sql_type(table_cols[key].type): continue try: casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) except Exception: - # Keep original dtype when values cannot be casted to the target schema. + # Keep original dtype when values cannot be cast to the target schema. continue - if decode_exprs: - df = df.with_columns(decode_exprs) if casts: df = df.with_columns(casts) + + # Reconstruct struct columns from their flat physical columns. + for key, dtype in struct_keys: + flat_cols = [fc for fc, _ in flatten_struct_dtype(key, dtype)] + present = [fc for fc in flat_cols if fc in df.columns] + if not present: + continue # struct was not part of this query; skip + missing = [fc for fc in flat_cols if fc not in df.columns] + if missing: + raise ValueError( + f"Struct attribute '{key}' is partially present in the DataFrame " + f"(missing: {missing}). Cannot reconstruct the struct column." 
+ ) + df = df.with_columns(self._build_struct_expr(key, dtype).alias(key)).drop(flat_cols) + return df def _update_max_id_per_time(self) -> None: @@ -848,6 +812,24 @@ def filter( include_sources=include_sources, ) + def _flatten_attrs_for_write( + self, + attrs: dict[str, Any], + schemas: dict[str, AttrSchema], + ) -> dict[str, Any]: + """Expand struct-typed values into flat ``{leaf_col: value}`` pairs. + + Non-struct values are passed through unchanged. + """ + result: dict[str, Any] = {} + for key, value in attrs.items(): + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct) and isinstance(value, dict): + result.update(flatten_struct_value(key, value, schema.dtype)) + else: + result[key] = value + return result + def add_node( self, attrs: dict[str, Any], @@ -907,6 +889,7 @@ def add_node( else: node_id = index + attrs = self._flatten_attrs_for_write(attrs, self._node_attr_schemas()) node = self.Node( node_id=node_id, **attrs, @@ -985,6 +968,8 @@ def bulk_add_nodes( node[DEFAULT_ATTR_KEYS.NODE_ID] = node_id node_ids.append(node_id) + node_schemas = self._node_attr_schemas() + nodes = [self._flatten_attrs_for_write(node, node_schemas) for node in nodes] self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node) if is_signal_on(self.node_added): @@ -1085,6 +1070,7 @@ def add_edge( if hasattr(target_id, "item"): target_id = target_id.item() + attrs = self._flatten_attrs_for_write(attrs, self._edge_attr_schemas()) edge = self.Edge( source_id=source_id, target_id=target_id, @@ -1138,9 +1124,12 @@ def bulk_add_edges( return [] return None + edge_schemas = self._edge_attr_schemas() for edge in edges: _data_numpy_to_native(edge) + edges = [self._flatten_attrs_for_write(edge, edge_schemas) for edge in edges] + if return_ids: with Session(self._engine) as session: result = session.execute(sa.insert(self.Edge).returning(self.Edge.edge_id), edges) @@ -1297,7 +1286,8 @@ def _get_neighbors( # all columns node_columns = 
[self.Node] else: - node_columns = [getattr(self.Node, key) for key in attr_keys] + # Expand struct logical keys to their flat physical columns. + node_columns = self._physical_cols_for_query(attr_keys, self.Node) query = session.query(getattr(self.Edge, node_key), *node_columns) query = query.join(self.Edge, getattr(self.Edge, neighbor_key) == self.Node.node_id) @@ -1489,9 +1479,9 @@ def node_attrs( if attr_keys is not None: # making them unique attr_keys = list(dict.fromkeys(attr_keys)) - + # Expand struct logical keys to their flat physical columns. query = query.with_only_columns( - *[getattr(self.Node, key) for key in attr_keys], + *self._physical_cols_for_query(attr_keys, self.Node), ) nodes_df = pl.read_database( @@ -1502,7 +1492,7 @@ def node_attrs( nodes_df = unpickle_bytes_columns(nodes_df) nodes_df = self._cast_array_columns(self.Node, nodes_df) - # indices are included by default and must be removed + # Select using logical keys (struct columns are now reconstructed). if attr_keys is not None: nodes_df = nodes_df.select([pl.col(c) for c in attr_keys]) else: @@ -1526,17 +1516,17 @@ def edge_attrs( query = sa.select(self.Edge) if attr_keys is not None: - attr_keys = set(attr_keys) + attr_keys = list(dict.fromkeys(attr_keys)) # we always return the source and target id by default - attr_keys.add(DEFAULT_ATTR_KEYS.EDGE_ID) - attr_keys.add(DEFAULT_ATTR_KEYS.EDGE_SOURCE) - attr_keys.add(DEFAULT_ATTR_KEYS.EDGE_TARGET) - attr_keys = list(attr_keys) + for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + if id_key not in attr_keys: + attr_keys.append(id_key) LOG.info("Edge attribute keys: %s", attr_keys) + # Expand struct logical keys to their flat physical columns. 
query = query.with_only_columns( - *[getattr(self.Edge, key) for key in attr_keys], + *self._physical_cols_for_query(attr_keys, self.Edge), ) edges_df = pl.read_database( @@ -1549,7 +1539,9 @@ def edge_attrs( if unpack: edges_df = unpack_array_attrs(edges_df) - elif attr_keys is None: + elif attr_keys is not None: + edges_df = edges_df.select([pl.col(c) for c in attr_keys if c in edges_df.columns]) + else: edges_df = edges_df.select([pl.col(c) for c in self._edge_attr_schemas() if c in edges_df.columns]) return edges_df @@ -1560,6 +1552,24 @@ def _node_attr_schemas(self) -> dict[str, AttrSchema]: def _edge_attr_schemas(self) -> dict[str, AttrSchema]: return self.__edge_attr_schemas + def _physical_cols_for_query( + self, + logical_keys: Sequence[str], + table_class: type[DeclarativeBase], + ) -> list[Any]: + """Return SQLAlchemy column objects for *logical_keys*, expanding struct keys + into their flat physical leaf columns so the SQL query fetches all necessary data.""" + schemas = self._attr_schemas_for_table(table_class) + cols: list[Any] = [] + for key in logical_keys: + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct): + for flat_col, _ in flatten_struct_dtype(key, schema.dtype): + cols.append(getattr(table_class, flat_col)) + else: + cols.append(getattr(table_class, key)) + return cols + def node_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the nodes. @@ -1570,7 +1580,7 @@ def node_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include NODE_ID in the returned keys. Defaults to False. If True, NODE_ID will be included in the list. 
""" - keys = list(self.Node.__table__.columns.keys()) + keys = list(self._node_attr_schemas().keys()) if not return_ids and DEFAULT_ATTR_KEYS.NODE_ID in keys: keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) return keys @@ -1585,7 +1595,7 @@ def edge_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. Defaults to False. If True, these ID fields will be included in the list. """ - keys = list(self.Edge.__table__.columns.keys()) + keys = list(self._edge_attr_schemas().keys()) if not return_ids: for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: if id_key in keys: @@ -1778,28 +1788,24 @@ def _sqlalchemy_type_inference(self, default_value: Any) -> TypeEngine: else: raise ValueError(f"Unsupported default value type: {type(default_value)}") - def _add_new_column( + def _add_physical_column( self, table_class: type[DeclarativeBase], - schema: AttrSchema, + col_name: str, + sa_type: Any, + default_value: Any, ) -> None: - # Convert polars dtype to SQLAlchemy type - sa_type = polars_dtype_to_sqlalchemy_type(schema.dtype) - - # Handle special cases for default value encoding - default_value = schema.default_value + """Create a single physical SQL column and register it on the ORM class.""" if isinstance(sa_type, sa.PickleType) and default_value is not None: - # Pickle complex types for database storage default_value = blob_default(self._engine, cloudpickle.dumps(default_value)) - sa_column = sa.Column(schema.key, sa_type, default=default_value) + sa_column = sa.Column(col_name, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) identifier_preparer = self._engine.dialect.identifier_preparer quoted_table_name = identifier_preparer.format_table(table_class.__table__) quoted_column_name = identifier_preparer.quote(sa_column.name) - # Properly quote default values based on type if 
isinstance(default_value, str): quoted_default = f"'{default_value}'" elif default_value is None: @@ -1814,16 +1820,40 @@ def _add_new_column( ) LOG.info("add %s column statement:\n'%s'", table_class.__table__, add_column_stmt) - # create the new column in the database with Session(self._engine) as session: session.execute(add_column_stmt) session.commit() - # register the new column in the Node class - setattr(table_class, schema.key, sa_column) + setattr(table_class, col_name, sa_column) table_class.__table__.append_column(sa_column) + def _add_new_column( + self, + table_class: type[DeclarativeBase], + schema: AttrSchema, + ) -> None: + """Add a new attribute column (or flat leaf columns for structs) to *table_class*.""" + if isinstance(schema.dtype, pl.Struct): + # Expand struct into one physical column per leaf field. + flat_defaults = flatten_struct_value(schema.key, schema.default_value or {}, schema.dtype) + for flat_col, leaf_dtype in flatten_struct_dtype(schema.key, schema.dtype): + self._add_physical_column( + table_class, + flat_col, + polars_dtype_to_sqlalchemy_type(leaf_dtype), + flat_defaults.get(flat_col), + ) + return + + self._add_physical_column( + table_class, + schema.key, + polars_dtype_to_sqlalchemy_type(schema.dtype), + schema.default_value, + ) + def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: + """Drop a single physical column from *table_class*.""" identifier_preparer = self._engine.dialect.identifier_preparer quoted_table_name = identifier_preparer.format_table(table_class.__table__) quoted_column_name = identifier_preparer.quote(key) @@ -1860,7 +1890,12 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") node_schemas = self.__node_attr_schemas - self._drop_column(self.Node, key) + schema = node_schemas.get(key) + if schema and isinstance(schema.dtype, pl.Struct): + for flat_col, _ in flatten_struct_dtype(key, schema.dtype): + 
self._drop_column(self.Node, flat_col) + else: + self._drop_column(self.Node, key) node_schemas.pop(key, None) self.__node_attr_schemas = node_schemas @@ -1884,7 +1919,12 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") edge_schemas = self.__edge_attr_schemas - self._drop_column(self.Edge, key) + schema = edge_schemas.get(key) + if schema and isinstance(schema.dtype, pl.Struct): + for flat_col, _ in flatten_struct_dtype(key, schema.dtype): + self._drop_column(self.Edge, flat_col) + else: + self._drop_column(self.Edge, key) edge_schemas.pop(key, None) self.__edge_attr_schemas = edge_schemas @@ -1923,6 +1963,8 @@ def _update_table( # Handle array values with bulk_update_mappings attrs = attrs.copy() _data_numpy_to_native(attrs) + schemas = self._attr_schemas_for_table(table_class) + attrs = self._flatten_attrs_for_write(attrs, schemas) # specialized case for scalar values - use simple bulk update if all(np.isscalar(v) for v in attrs.values()): diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index dc3d6dd0..152b3d50 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -395,6 +395,71 @@ def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: } +STRUCT_FIELD_SEP = "__" + + +def flatten_struct_dtype( + key: str, + dtype: pl.Struct, + sep: str = STRUCT_FIELD_SEP, +) -> list[tuple[str, pl.DataType]]: + """Recursively return ``(flat_column_name, leaf_dtype)`` for all leaves of a struct. + + Parameters + ---------- + key : str + The root column name (or already-accumulated flat prefix for nested calls). + dtype : pl.Struct + The struct dtype to flatten. + sep : str + Separator between path components. Defaults to ``STRUCT_FIELD_SEP``. 
+ + Examples + -------- + >>> flatten_struct_dtype("m", pl.Struct({"score": pl.Int64, "label": pl.String})) + [("m__score", Int64), ("m__label", String)] + """ + results: list[tuple[str, pl.DataType]] = [] + for field_name, field_dtype in dtype.to_schema().items(): + flat_key = f"{key}{sep}{field_name}" + if isinstance(field_dtype, pl.Struct): + results.extend(flatten_struct_dtype(flat_key, field_dtype, sep)) + else: + results.append((flat_key, field_dtype)) + return results + + +def flatten_struct_value( + key: str, + value: dict, + dtype: pl.Struct, + sep: str = STRUCT_FIELD_SEP, +) -> dict: + """Flatten a struct dict value into ``{flat_col: scalar}`` pairs. + + Parameters + ---------- + key : str + The root column name. + value : dict + The struct value to flatten (may be ``None`` or empty). + dtype : pl.Struct + The struct dtype describing the expected fields. + sep : str + Separator. Defaults to ``STRUCT_FIELD_SEP``. + """ + result: dict = {} + value = value or {} + for field_name, field_dtype in dtype.to_schema().items(): + flat_key = f"{key}{sep}{field_name}" + field_val = value.get(field_name) + if isinstance(field_dtype, pl.Struct): + result.update(flatten_struct_value(flat_key, field_val or {}, field_dtype, sep)) + else: + result[flat_key] = field_val + return result + + def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: """ Convert a polars dtype to SQLAlchemy type. @@ -416,10 +481,6 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: >>> polars_dtype_to_sqlalchemy_type(pl.Boolean) """ - # Handle struct types as JSON for backend-level field filtering. 
- if isinstance(dtype, pl.Struct): - return sa.JSON() - # Handle sequence types - use PickleType for storage if isinstance(dtype, pl.Array | pl.List): return sa.PickleType() @@ -444,7 +505,6 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: (sa.Float, pl.Float64), (sa.Text, pl.String), # Must come before String (sa.String, pl.String), - (sa.JSON, pl.Object), (sa.PickleType, pl.Object), # Must come before LargeBinary (sa.LargeBinary, pl.Object), ] From d9bee26769ca346e2e460c47c36bb655b291e54b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Fri, 10 Apr 2026 13:02:53 +0300 Subject: [PATCH 18/30] removed codex config wrongly added --- .codex/environments/environment.toml | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .codex/environments/environment.toml diff --git a/.codex/environments/environment.toml b/.codex/environments/environment.toml deleted file mode 100644 index 1324ca94..00000000 --- a/.codex/environments/environment.toml +++ /dev/null @@ -1,10 +0,0 @@ -# THIS IS AUTOGENERATED. 
DO NOT EDIT MANUALLY -version = 1 -name = "tracksdata" - -[setup] -script = ''' -uv venv -uv pip install -e .[spatial,test,docs] -source .venv/bin/activate -''' From cc0beb4c0e76cedb0c418baf95dda9406a4fc0af Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 14 Apr 2026 16:44:19 +0900 Subject: [PATCH 19/30] issue fixes --- src/tracksdata/graph/_sql_graph.py | 50 ++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 9dd4b570..2b5d7038 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -266,19 +266,19 @@ def node_attrs( schema_overrides=self._graph._polars_schema_override(self._graph.Node), ) - if attr_keys is not None: - nodes_attrs = nodes_attrs.select(attr_keys) - nodes_attrs = unpickle_bytes_columns(nodes_attrs) nodes_attrs = self._graph._cast_array_columns(self._graph.Node, nodes_attrs) + if attr_keys is not None: + nodes_attrs = nodes_attrs.select(attr_keys) + if unpack: nodes_attrs = unpack_array_attrs(nodes_attrs) return nodes_attrs - @staticmethod def _query_from_attr_keys( + self, query: sa.Select, table: type[DeclarativeBase], attr_keys: list[str] | None = None, @@ -292,14 +292,23 @@ def _query_from_attr_keys( LOG.info("Query attr_keys: %s", attr_keys) + schemas = self._graph._attr_schemas_for_table(table) + flat_names: list[str] = [] + for key in attr_keys: + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct): + flat_names.extend(fc for fc, _ in flatten_struct_dtype(key, schema.dtype)) + else: + flat_names.append(key) + if isinstance(query, sa.CompoundSelect): union_query = query.alias("u") query = sa.select( - *[getattr(union_query.c, key) for key in attr_keys], + *[getattr(union_query.c, name) for name in flat_names], ) else: query = query.with_only_columns( - *[getattr(table, key) for key in attr_keys], + *[getattr(table, name) for name in flat_names], ) 
LOG.info("Query after attr_keys selection:\n%s", query) @@ -889,10 +898,10 @@ def add_node( else: node_id = index - attrs = self._flatten_attrs_for_write(attrs, self._node_attr_schemas()) + write_attrs = self._flatten_attrs_for_write(attrs, self._node_attr_schemas()) node = self.Node( node_id=node_id, - **attrs, + **write_attrs, ) with Session(self._engine) as session: @@ -969,8 +978,8 @@ def bulk_add_nodes( node_ids.append(node_id) node_schemas = self._node_attr_schemas() - nodes = [self._flatten_attrs_for_write(node, node_schemas) for node in nodes] - self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node) + write_nodes = [self._flatten_attrs_for_write(node, node_schemas) for node in nodes] + self._chunked_sa_write(Session.bulk_insert_mappings, write_nodes, self.Node) if is_signal_on(self.node_added): for node_id, node_attrs in zip(node_ids, nodes, strict=True): @@ -1004,7 +1013,10 @@ def remove_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} does not exist in the graph.") if is_signal_on(self.node_removed): - old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()} + attr_keys = self.node_attr_keys() + old_df = self.filter(node_ids=[node_id]).node_attrs(attr_keys=attr_keys) + old_row = old_df.row(0, named=True) + old_attrs = {key: old_row[key] for key in attr_keys} # Remove all edges where this node is source or target session.query(self.Edge).filter( @@ -1628,13 +1640,19 @@ def _resolve_attr_keys( if len(attr_keys) == 0: raise ValueError("attr_keys must contain at least one column name") - missing = [key for key in attr_keys if key not in table_class.__table__.columns] + schemas = self._attr_schemas_for_table(table_class) + physical_names: list[str] = [] + for key in attr_keys: + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct): + physical_names.extend(fc for fc, _ in flatten_struct_dtype(key, schema.dtype)) + else: + physical_names.append(key) + + missing = [name for name 
in physical_names if name not in table_class.__table__.columns] if missing: raise ValueError(f"Columns {missing} do not exist on table {table_class.__tablename__}") - resolved_columns = [getattr(table_class, key) for key in attr_keys] - - if isinstance(attr_keys, str): - attr_keys = [attr_keys] + resolved_columns = [getattr(table_class, name) for name in physical_names] cols_fragment = "_".join(attr_keys) name = f"ix_{table_class.__tablename__.lower()}_{cols_fragment}" From ffea2ecd8db4c44fbf437b5d29db73f9f030135e Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 14 Apr 2026 17:21:29 +0900 Subject: [PATCH 20/30] rolled back unncessary change --- src/tracksdata/graph/_sql_graph.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index effa415b..f268efbf 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -266,12 +266,12 @@ def node_attrs( schema_overrides=self._graph._polars_schema_override(self._graph.Node), ) - nodes_attrs = unpickle_bytes_columns(nodes_attrs) - nodes_attrs = self._graph._cast_array_columns(self._graph.Node, nodes_attrs) - if attr_keys is not None: nodes_attrs = nodes_attrs.select(attr_keys) + nodes_attrs = unpickle_bytes_columns(nodes_attrs) + nodes_attrs = self._graph._cast_columns(self._graph.Node, nodes_attrs) + if unpack: nodes_attrs = unpack_array_attrs(nodes_attrs) @@ -336,7 +336,7 @@ def edge_attrs(self, attr_keys: list[str] | None = None, unpack: bool = False) - ) edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._graph._cast_array_columns(self._graph.Edge, edges_df) + edges_df = self._graph._cast_columns(self._graph.Edge, edges_df) if unpack: edges_df = unpack_array_attrs(edges_df) @@ -747,7 +747,7 @@ def _build_struct_expr(key: str, dtype: pl.Struct) -> pl.Expr: fields.append(pl.col(flat_col).alias(field_name)) return pl.struct(fields) - def 
_cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: + def _cast_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: """Cast pickled columns to their target dtype and reconstruct struct columns.""" schemas = self._attr_schemas_for_table(table_class) table_cols = table_class.__table__.columns @@ -1327,7 +1327,7 @@ def _get_neighbors( self.Node, ) node_df = unpickle_bytes_columns(node_df) - node_df = self._cast_array_columns(self.Node, node_df) + node_df = self._cast_columns(self.Node, node_df) if single_node: if not return_attrs: @@ -1520,7 +1520,7 @@ def node_attrs( schema_overrides=self._polars_schema_override(self.Node), ) nodes_df = unpickle_bytes_columns(nodes_df) - nodes_df = self._cast_array_columns(self.Node, nodes_df) + nodes_df = self._cast_columns(self.Node, nodes_df) # Select using logical keys (struct columns are now reconstructed). if attr_keys is not None: @@ -1546,11 +1546,12 @@ def edge_attrs( query = sa.select(self.Edge) if attr_keys is not None: - attr_keys = list(dict.fromkeys(attr_keys)) + attr_keys = set(attr_keys) # we always return the source and target id by default - for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: - if id_key not in attr_keys: - attr_keys.append(id_key) + attr_keys.add(DEFAULT_ATTR_KEYS.EDGE_ID) + attr_keys.add(DEFAULT_ATTR_KEYS.EDGE_SOURCE) + attr_keys.add(DEFAULT_ATTR_KEYS.EDGE_TARGET) + attr_keys = list(attr_keys) LOG.info("Edge attribute keys: %s", attr_keys) @@ -1565,7 +1566,7 @@ def edge_attrs( schema_overrides=self._polars_schema_override(self.Edge), ) edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._cast_array_columns(self.Edge, edges_df) + edges_df = self._cast_columns(self.Edge, edges_df) if unpack: edges_df = unpack_array_attrs(edges_df) From 5e7331f482fdaedd0a3163e2bb12c84d783316d6 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 28 May 2026 14:11:04 +0900 
Subject: [PATCH 21/30] additional comments --- src/tracksdata/attrs.py | 17 ++++++++++++++++- src/tracksdata/graph/_rustworkx_graph.py | 5 +++++ src/tracksdata/graph/_sql_graph.py | 19 ++++++++++++++++++- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index be223211..5a425b35 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -129,6 +129,9 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr raise ValueError(f"Comparison operators are not supported for multiple columns. Found {columns}.") self.attr = attr + # Prefer the explicitly tracked root_column so struct-field comparisons + # (e.g. `NodeAttr("m").struct.field("x") == 1`) record the parent storage + # column ("m"), letting backends remap to their physical layout via field_path. self.column = attr.root_column if attr.root_column is not None else columns[0] self.op = op @@ -203,13 +206,21 @@ def __rge__(self, other: ExprInput) -> "Attr": ... class _StructNamespace: - """Wrapper around polars struct namespace that preserves Attr semantics.""" + """Wrapper around polars struct namespace that preserves Attr semantics. + + Polars' own ``Expr.struct.field(name)`` only updates the underlying expression; + it loses the parent column identity, which backends need to map a filter back + to its physical storage (e.g. SQL flat columns, dict lookups in rustworkx). + This wrapper proxies the namespace while threading ``root_column`` and + ``field_path`` through ``.field(...)`` calls. + """ def __init__(self, attr: "Attr") -> None: self._attr = attr self._namespace = attr.expr.struct def field(self, name: str) -> "Attr": + # preserve_field_path keeps the existing root/path before appending the new field. 
out = self._attr._wrap(self._namespace.field(name), preserve_field_path=True) if isinstance(out, Attr): out._append_field_path(name) @@ -811,4 +822,8 @@ def polars_reduce_attr_comps( # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") + # Apply each comparison against the full Attr expression rather than the bare + # column from the dataframe. This matters for struct-field accesses: the + # expression already drills into the struct (e.g. `pl.col("m").struct.field("x")`), + # while `df[column]` would yield the whole struct and the comparison would fail. return pl.reduce(reduce_op, [attr_comp.op(attr_comp.attr.expr, attr_comp.other) for attr_comp in attr_comps]) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index d956f0bd..7c29442e 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -89,6 +89,11 @@ def _create_filter_func( LOG.info(f"Creating filter function for {attr_comps}") def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any: + # Rustworkx stores attributes as plain Python objects (typically dicts for + # struct attrs) rather than polars columns, so struct-field filters can't be + # pushed down into an expression — we walk the path manually here. We also + # accept sequence- and attribute-style access to keep this robust for users + # who pass nested dataclasses or tuples through the attr dict. for field in field_path: if value is None: return None diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index b707aa97..198b44df 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -70,6 +70,12 @@ def _resolve_attr_filter_column( For struct field paths (e.g. ``NodeAttr("m").struct.field("score")``), the field path is joined with ``STRUCT_FIELD_SEP`` to form the physical flat column name (e.g. 
``m__score``), which is a native SQL column. + + Struct attributes are stored as one physical SQL column per leaf field + (not as JSON blobs) so that filtering on a struct field is a native SQL + predicate on a leaf column rather than a server-side JSON path lookup — + the parent ``Attr``'s ``root_column`` and ``field_path`` provide the + logical-to-physical mapping. """ if not attr_filter.attr.field_path: return getattr(table, str(attr_filter.column)) @@ -652,6 +658,8 @@ def _attr_schemas_from_metadata( # Compute the set of flat physical columns that belong to known struct schemas, # so the legacy fallback below does not register them as independent logical keys. + # Without this, reloading a DB with a struct attribute "m" would re-expose + # "m__score" / "m__label" as their own top-level attributes. known_flat_cols: set[str] = set() for schema in schemas.values(): if isinstance(schema.dtype, pl.Struct): @@ -1616,7 +1624,13 @@ def _physical_cols_for_query( table_class: type[DeclarativeBase], ) -> list[Any]: """Return SQLAlchemy column objects for *logical_keys*, expanding struct keys - into their flat physical leaf columns so the SQL query fetches all necessary data.""" + into their flat physical leaf columns so the SQL query fetches all necessary data. + + Logical keys are what the user sees (``"measurements"``); physical columns are + what actually exists in the table (``"measurements__score"``, ...). The two + diverge only for struct attributes; ``_cast_columns`` reassembles the struct + on the result DataFrame. + """ schemas = self._attr_schemas_for_table(table_class) cols: list[Any] = [] for key in logical_keys: @@ -1638,6 +1652,9 @@ def node_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include NODE_ID in the returned keys. Defaults to False. If True, NODE_ID will be included in the list. 
""" + # Read from schemas (logical keys), not __table__.columns — the latter exposes + # struct leaves (``measurements__score``) as separate keys, but the public API + # should only surface the parent ``measurements``. keys = list(self._node_attr_schemas().keys()) if not return_ids and DEFAULT_ATTR_KEYS.NODE_ID in keys: keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) From 37f8fc9b6dade4885b64ba9834f67a210f4cd177 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Mon, 1 Jun 2026 11:33:39 -0700 Subject: [PATCH 22/30] Fix lint: remove whitespace from blank lines --- src/tracksdata/graph/_sql_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 5cfdb380..854d3d9f 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -85,7 +85,7 @@ def _resolve_attr_filter_column( flat_col = STRUCT_FIELD_SEP.join([str(attr_filter.column), *attr_filter.attr.field_path]) return getattr(table, flat_col) - + # Module-level (not methods) so they can be registered with ``weakref.finalize`` # without holding a bound reference to the owning object, which would prevent # it from ever being collected. 
@@ -157,7 +157,7 @@ def _close_id_set(id_set: "_SQLIDSet") -> None: except Exception as exc: LOG.debug("Failed to close _SQLIDSet: %s", exc) - + def _filter_query( query: sa.Select, table: type[DeclarativeBase], From 1842c5557aeae0045c622072814366370e53d0e1 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 4 Jun 2026 01:29:12 -0700 Subject: [PATCH 23/30] fixes --- src/tracksdata/attrs.py | 4 ++-- src/tracksdata/utils/_dtypes.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 5a425b35..b0963543 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -222,8 +222,8 @@ def __init__(self, attr: "Attr") -> None: def field(self, name: str) -> "Attr": # preserve_field_path keeps the existing root/path before appending the new field. out = self._attr._wrap(self._namespace.field(name), preserve_field_path=True) - if isinstance(out, Attr): - out._append_field_path(name) + # _namespace.field() always returns a polars Expr, so _wrap always yields an Attr here. 
+ out._append_field_path(name) return out def __getattr__(self, name: str) -> Any: diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 152b3d50..5c25b710 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -431,7 +431,7 @@ def flatten_struct_dtype( def flatten_struct_value( key: str, - value: dict, + value: dict | None, dtype: pl.Struct, sep: str = STRUCT_FIELD_SEP, ) -> dict: @@ -454,7 +454,7 @@ def flatten_struct_value( flat_key = f"{key}{sep}{field_name}" field_val = value.get(field_name) if isinstance(field_dtype, pl.Struct): - result.update(flatten_struct_value(flat_key, field_val or {}, field_dtype, sep)) + result.update(flatten_struct_value(flat_key, field_val, field_dtype, sep)) else: result[flat_key] = field_val return result From db3528773eae93ff847d7eb4e0fdc24055afefeb Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 4 Jun 2026 16:27:14 -0700 Subject: [PATCH 24/30] refactor aligning main --- src/tracksdata/graph/_sql_graph.py | 387 ++++++-------------- src/tracksdata/graph/_test/test_subgraph.py | 135 ------- 2 files changed, 113 insertions(+), 409 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 854d3d9f..dccc70e0 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,7 +1,5 @@ import binascii import re -import uuid -import weakref from collections.abc import Callable, Sequence from enum import Enum from pathlib import Path @@ -86,78 +84,6 @@ def _resolve_attr_filter_column( return getattr(table, flat_col) -# Module-level (not methods) so they can be registered with ``weakref.finalize`` -# without holding a bound reference to the owning object, which would prevent -# it from ever being collected. -def _drop_scratch_table(engine: sa.Engine, table: sa.Table) -> None: - """Drop a scratch table, swallowing errors (e.g. 
at interpreter shutdown).""" - try: - table.drop(engine) - except Exception as exc: - LOG.debug("Failed to drop scratch table %s: %s", table.name, exc) - - -class _SQLIDSet: - """A set of ids usable in SQL ``IN`` clauses without overflowing bind limits. - - Small sets compile to inline ``col.in_([...])``; larger sets are materialized - into a per-instance scratch table on ``graph._engine`` and matched via - ``col.in_(SELECT id FROM scratch)``. The scratch table is a regular table - (not ``TEMPORARY``) so it is visible from any pool connection the filter - later uses; the caller drops it via :meth:`close` once the queries that - reference it are no longer needed. - - ``occurrences`` is the maximum number of times the id set will be expanded - in a single compiled statement (e.g. filtering both ``source_id`` and - ``target_id`` of an edge table counts as 2). The scratch-table cutoff is - divided by it so that ``len(ids) * occurrences`` stays safely under the - backend's bound-variable limit. - """ - - def __init__( - self, - graph: "SQLGraph", - ids: Sequence[int], - *, - occurrences: int = 1, - ) -> None: - if hasattr(ids, "tolist"): - ids = ids.tolist() - self._ids: list[int] = list(ids) - # Hold the engine, not the graph, so this set does not participate in - # the graph -> SQLFilter -> _SQLIDSet -> graph reference cycle. - # Otherwise the scratch table would only be dropped after Python's - # cycle GC runs, delaying cleanup in long-running processes. 
- self._engine = graph._engine - - limit = max(1, graph._sql_chunk_size() // max(1, occurrences)) - if len(self._ids) > limit: - self._scratch: sa.Table | None = graph._create_id_scratch_table(self._ids) - else: - self._scratch = None - - @property - def uses_scratch_table(self) -> bool: - return self._scratch is not None - - def in_clause(self, column: sa.ColumnElement) -> "sa.ColumnElement[bool]": - if self._scratch is None: - return column.in_(self._ids) - return column.in_(sa.select(self._scratch.c.id)) - - def close(self) -> None: - if self._scratch is not None: - _drop_scratch_table(self._engine, self._scratch) - self._scratch = None - - -def _close_id_set(id_set: "_SQLIDSet") -> None: - try: - id_set.close() - except Exception as exc: - LOG.debug("Failed to close _SQLIDSet: %s", exc) - - def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -194,17 +120,6 @@ def _filter_query( class SQLFilter(BaseFilter): - """SQL-backed filter over an :class:`SQLGraph`. - - When ``node_ids`` is larger than the backend's bound-variable budget - (after accounting for how many ``IN (...)`` clauses the list expands - into), the filter materializes the ids into a per-instance scratch - table on ``graph._engine`` and references it via subselects. The - scratch table is dropped when the filter is garbage-collected (via - :func:`weakref.finalize`), so callers don't need to close the filter - explicitly. - """ - def __init__( self, *attr_filters: AttrComparison, @@ -218,7 +133,6 @@ def __init__( self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters) self._include_targets = include_targets self._include_sources = include_sources - self._id_set: _SQLIDSet | None = None # creating initial query self._node_query: sa.Select = sa.select(self._graph.Node) @@ -226,20 +140,18 @@ def __init__( node_filtered = False if node_ids is not None: - # The node_ids list is expanded in up to three IN(...) 
clauses - # below (once on Node, plus once each on Edge.target_id / - # Edge.source_id unless the corresponding ``include_*`` is set). - # Account for that so the inline/scratch cutoff stays below the - # backend's bound-variable limit for the compiled statement. - occurrences = 1 + int(not self._include_targets) + int(not self._include_sources) - id_set = _SQLIDSet(self._graph, node_ids, occurrences=occurrences) - self._id_set = id_set - - self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id)) + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + + self._node_query = self._node_query.filter(self._graph.Node.node_id.in_(node_ids)) if not self._include_targets: - self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.target_id)) + self._edge_query = self._edge_query.filter( + self._graph.Edge.target_id.in_(node_ids), + ) if not self._include_sources: - self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.source_id)) + self._edge_query = self._edge_query.filter( + self._graph.Edge.source_id.in_(node_ids), + ) node_filtered = True if self._node_attr_comps: @@ -320,16 +232,6 @@ def __init__( self._node_query = sa.union(*nodes_query) - # Drop the scratch table when this filter is collected. Only register a - # finalizer if one was actually allocated, so the common small-set case - # stays free of weakref bookkeeping. 
- if self._uses_scratch_table(): - weakref.finalize(self, _close_id_set, self._id_set) - - def _uses_scratch_table(self) -> bool: - """Whether the id set backing this filter materialized a scratch table.""" - return self._id_set is not None and self._id_set.uses_scratch_table - @cache_method def node_ids(self) -> list[int]: """ @@ -410,14 +312,7 @@ def _query_from_attr_keys( LOG.info("Query attr_keys: %s", attr_keys) - schemas = self._graph._attr_schemas_for_table(table) - flat_names: list[str] = [] - for key in attr_keys: - schema = schemas.get(key) - if schema is not None and isinstance(schema.dtype, pl.Struct): - flat_names.extend(fc for fc, _ in flatten_struct_dtype(key, schema.dtype)) - else: - flat_names.append(key) + flat_names = self._graph._physical_column_names(attr_keys, table) if isinstance(query, sa.CompoundSelect): union_query = query.alias("u") @@ -1716,13 +1611,24 @@ def _node_attr_schemas(self) -> dict[str, AttrSchema]: def _edge_attr_schemas(self) -> dict[str, AttrSchema]: return self.__edge_attr_schemas - def _physical_cols_for_query( + @staticmethod + def _leaf_column_names(key: str, schema: AttrSchema | None) -> list[str]: + """Physical column name(s) backing a single logical attribute *key*. + + Struct attributes are stored as one physical column per leaf field, so they + expand to ``["key__a", "key__b", ...]``; every other attribute maps to the + single column ``[key]``. + """ + if schema is not None and isinstance(schema.dtype, pl.Struct): + return [flat_col for flat_col, _ in flatten_struct_dtype(key, schema.dtype)] + return [key] + + def _physical_column_names( self, logical_keys: Sequence[str], table_class: type[DeclarativeBase], - ) -> list[Any]: - """Return SQLAlchemy column objects for *logical_keys*, expanding struct keys - into their flat physical leaf columns so the SQL query fetches all necessary data. + ) -> list[str]: + """Expand logical attribute keys to the physical column names backing them. 
Logical keys are what the user sees (``"measurements"``); physical columns are what actually exists in the table (``"measurements__score"``, ...). The two @@ -1730,15 +1636,15 @@ def _physical_cols_for_query( on the result DataFrame. """ schemas = self._attr_schemas_for_table(table_class) - cols: list[Any] = [] - for key in logical_keys: - schema = schemas.get(key) - if schema is not None and isinstance(schema.dtype, pl.Struct): - for flat_col, _ in flatten_struct_dtype(key, schema.dtype): - cols.append(getattr(table_class, flat_col)) - else: - cols.append(getattr(table_class, key)) - return cols + return [name for key in logical_keys for name in self._leaf_column_names(key, schemas.get(key))] + + def _physical_cols_for_query( + self, + logical_keys: Sequence[str], + table_class: type[DeclarativeBase], + ) -> list[Any]: + """Like :meth:`_physical_column_names`, but returning SQLAlchemy column objects.""" + return [getattr(table_class, name) for name in self._physical_column_names(logical_keys, table_class)] def node_attr_keys(self, return_ids: bool = False) -> list[str]: """ @@ -1801,14 +1707,7 @@ def _resolve_attr_keys( if len(attr_keys) == 0: raise ValueError("attr_keys must contain at least one column name") - schemas = self._attr_schemas_for_table(table_class) - physical_names: list[str] = [] - for key in attr_keys: - schema = schemas.get(key) - if schema is not None and isinstance(schema.dtype, pl.Struct): - physical_names.extend(fc for fc, _ in flatten_struct_dtype(key, schema.dtype)) - else: - physical_names.append(key) + physical_names = self._physical_column_names(attr_keys, table_class) missing = [name for name in physical_names if name not in table_class.__table__.columns] if missing: @@ -2045,6 +1944,11 @@ def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: # refresh ORM schema to reflect database changes self._define_schema(overwrite=False) + def _drop_attr_columns(self, table_class: type[DeclarativeBase], key: str, schema: 
AttrSchema | None) -> None: + """Drop the physical column(s) backing a logical attribute *key* (one per struct leaf).""" + for name in self._leaf_column_names(key, schema): + self._drop_column(table_class, name) + def add_node_attr_key( self, key_or_schema: str | AttrSchema, @@ -2068,12 +1972,7 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") node_schemas = self.__node_attr_schemas - schema = node_schemas.get(key) - if schema and isinstance(schema.dtype, pl.Struct): - for flat_col, _ in flatten_struct_dtype(key, schema.dtype): - self._drop_column(self.Node, flat_col) - else: - self._drop_column(self.Node, key) + self._drop_attr_columns(self.Node, key, node_schemas.get(key)) node_schemas.pop(key, None) self.__node_attr_schemas = node_schemas @@ -2097,12 +1996,7 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") edge_schemas = self.__edge_attr_schemas - schema = edge_schemas.get(key) - if schema and isinstance(schema.dtype, pl.Struct): - for flat_col, _ in flatten_struct_dtype(key, schema.dtype): - self._drop_column(self.Edge, flat_col) - else: - self._drop_column(self.Edge, key) + self._drop_attr_columns(self.Edge, key, edge_schemas.get(key)) edge_schemas.pop(key, None) self.__edge_attr_schemas = edge_schemas @@ -2245,40 +2139,6 @@ def _chunked_sa_read( chunks.append(data_df) return pl.concat(chunks) - def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table: - """Create a uniquely-named helper table holding ``ids`` on ``self._engine``. - - Used to work around SQL bound-variable limits when filtering by large - ``IN (...)`` lists: callers replace ``col.in_(ids)`` with - ``col.in_(sa.select(table.c.id))``. 
The table is a regular table on - the engine (not ``TEMPORARY``), so it is visible from any session or - connection drawn from the same engine pool — that is what makes it - usable across the multiple ``Session(engine)`` calls inside - :class:`SQLFilter`. - - The caller owns the table's lifetime and must eventually call - ``table.drop(self._engine)`` (or hand the table off to a finalizer - that does so) to remove it. - """ - unique_ids = list({int(v) for v in ids}) - - name = f"_tracksdata_ids_{uuid.uuid4().hex}" - table = sa.Table( - name, - sa.MetaData(), - sa.Column("id", sa.BigInteger, primary_key=True), - ) - table.create(self._engine) - - chunk_size = max(1, self._sql_chunk_size()) - with self._engine.begin() as conn: - for i in range(0, len(unique_ids), chunk_size): - conn.execute( - table.insert(), - [{"id": v} for v in unique_ids[i : i + chunk_size]], - ) - return table - def update_node_attrs( self, *, @@ -2373,22 +2233,13 @@ def _get_degree( with Session(self._engine) as session: return int(session.execute(stmt).scalar()) - base_stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) + stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) + if node_ids is not None: + stmt = stmt.where(edge_key_col.in_(node_ids)) - degree: dict[int, int] = {} with Session(self._engine) as session: - if node_ids is None: - degree.update(session.execute(base_stmt).all()) - else: - # Chunk the IN(...) so the bound-parameter count stays below - # the backend's limit (notably SQLite's - # ``SQLITE_MAX_VARIABLE_NUMBER``). Each chunk's group-by result - # is disjoint, so we can merge them with a simple dict update. 
- chunk_size = max(1, self._sql_chunk_size()) - for i in range(0, len(node_ids), chunk_size): - chunk = node_ids[i : i + chunk_size] - stmt = base_stmt.where(edge_key_col.in_(chunk)) - degree.update(session.execute(stmt).all()) + # get the number of edges for each using group by and count + degree = dict(session.execute(stmt).all()) if node_ids is None: # this is necessary to make sure it's the same order as node_ids @@ -2503,10 +2354,9 @@ def _sqlite_table_dump( reflection path then rebuilds the in-memory state. For filtered copies (``source_node_ids`` not ``None``) the selection - is materialized in a per-instance scratch table on the source engine - so the row filter joins instead of using an oversized ``IN (...)`` - clause that would hit SQLite's bound-parameter limit. The scratch - table is dropped in the ``finally`` block before returning. + is materialized in a temp table so the row filter joins instead of + using an oversized ``IN (...)`` clause that would hit SQLite's + bound-parameter limit. """ dst_database: str = kwargs["database"] dst_path = Path(dst_database) @@ -2522,80 +2372,69 @@ def _sqlite_table_dump( # escape the path safely via single-quote doubling. attach_path = dst_database.replace("'", "''") - if source_node_ids is None: - selected: sa.Table | None = None - else: - # Materialize the selection in a per-instance scratch table so the - # row filter joins instead of expanding into an oversized IN(...). - # The table lives on ``source_root._engine`` (visible from the - # ATTACH-ing connection) and is dropped in the outer ``finally``. - # - # We deliberately do not use a ``TEMPORARY`` table here even - # though this function holds a single connection. SQLAlchemy's - # ``Connection.close()`` only returns the underlying DB-API - # connection to the pool, it does not destroy it, so a TEMP table - # would survive into the next consumer of that same pooled SQLite - # connection. A regular table dropped explicitly avoids that. 
- selected = source_root._create_id_scratch_table(source_node_ids) - - try: - with source_root._engine.connect() as conn: - conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst") - try: - # 1. Replicate the source schema by replaying its DDL against - # the attached destination. ``sqlite_master.sql`` is NULL for - # auto-generated objects (e.g. PK indexes), which we skip; - # tables are created before indexes. ``_tracksdata_ids_*`` - # are internal scratch tables (this call's own ``selected`` - # plus any from live ``SQLFilter``s on the same engine) and - # must not be copied into the persisted destination. - ddl_rows = conn.exec_driver_sql( - "SELECT type, sql FROM main.sqlite_master " - "WHERE sql IS NOT NULL AND type IN ('table', 'index') " - "AND name NOT GLOB '_tracksdata_ids_*' " - "ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END" - ).fetchall() - for _type, ddl in ddl_rows: - qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1) - conn.exec_driver_sql(qualified) - - # 2. Copy rows. The Metadata table is included verbatim — its - # SQL-private schema entries describe the columns we just - # cloned and so are valid for the destination as-is. 
- if selected is None: - for table_name in ("Node", "Edge", "Overlap", "Metadata"): - conn.exec_driver_sql( - f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"' - ) - else: - selected_subq = f'SELECT id FROM "{selected.name}"' - - conn.exec_driver_sql( - f'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" WHERE node_id IN ({selected_subq})' - ) - conn.exec_driver_sql( - f'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" ' - f"WHERE source_id IN ({selected_subq}) " - f"AND target_id IN ({selected_subq})" - ) - conn.exec_driver_sql( - f'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" ' - f"WHERE source_id IN ({selected_subq}) " - f"AND target_id IN ({selected_subq})" + with source_root._engine.connect() as conn: + conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst") + try: + # 1. Replicate the source schema by replaying its DDL against + # the attached destination. ``sqlite_master.sql`` is NULL for + # auto-generated objects (e.g. PK indexes), which we skip; + # tables are created before indexes. + ddl_rows = conn.exec_driver_sql( + "SELECT type, sql FROM main.sqlite_master " + "WHERE sql IS NOT NULL AND type IN ('table', 'index') " + "ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END" + ).fetchall() + for _type, ddl in ddl_rows: + qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1) + conn.exec_driver_sql(qualified) + + # 2. Copy rows. The Metadata table is included verbatim — its + # SQL-private schema entries describe the columns we just + # cloned and so are valid for the destination as-is. + if source_node_ids is None: + for table_name in ("Node", "Edge", "Overlap", "Metadata"): + conn.exec_driver_sql(f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"') + else: + node_ids = list(source_node_ids) + if hasattr(node_ids, "tolist"): + node_ids = node_ids.tolist() + # Materialize the selection in a temp table so the row + # filter joins instead of using an oversized IN(...) 
clause. + conn.exec_driver_sql("CREATE TEMP TABLE _td_selected (node_id INTEGER PRIMARY KEY)") + insert_stmt = sa.text("INSERT INTO _td_selected (node_id) VALUES (:node_id)") + chunk_size = max(1, source_root._sql_chunk_size()) + for i in range(0, len(node_ids), chunk_size): + batch = node_ids[i : i + chunk_size] + conn.execute( + insert_stmt, + [{"node_id": int(nid)} for nid in batch], ) - conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"') - - conn.commit() - finally: - conn.exec_driver_sql("DETACH DATABASE _td_dst") - - # 3. Open the destination from the now-populated file. The standard - # constructor reflects the schema, restores pickled column types, - # and recomputes ``_max_id_per_time``. - return cls(**kwargs) - finally: - if selected is not None: - _drop_scratch_table(source_root._engine, selected) + + conn.exec_driver_sql( + 'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" ' + "WHERE node_id IN (SELECT node_id FROM _td_selected)" + ) + conn.exec_driver_sql( + 'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" ' + "WHERE source_id IN (SELECT node_id FROM _td_selected) " + "AND target_id IN (SELECT node_id FROM _td_selected)" + ) + conn.exec_driver_sql( + 'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" ' + "WHERE source_id IN (SELECT node_id FROM _td_selected) " + "AND target_id IN (SELECT node_id FROM _td_selected)" + ) + conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"') + conn.exec_driver_sql("DROP TABLE _td_selected") + + conn.commit() + finally: + conn.exec_driver_sql("DETACH DATABASE _td_dst") + + # 3. Open the destination from the now-populated file. The standard + # constructor reflects the schema, restores pickled column types, + # and recomputes ``_max_id_per_time``. 
+ return cls(**kwargs) def __getstate__(self) -> dict: data_dict = self.__dict__.copy() diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 6197eca8..f38e54c1 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1,5 +1,3 @@ -import gc -import itertools import re from collections.abc import Callable from contextlib import contextmanager @@ -1339,136 +1337,3 @@ def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None: ) ) assert edge_list == expected_edge_list - - -def _build_chain_graph(graph: SQLGraph, n_nodes: int) -> list[int]: - node_ids: list[int] = [] - for t in range(n_nodes): - node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t})) - for src, tgt in itertools.pairwise(node_ids): - graph.add_edge(src, tgt, {}) - graph.add_overlap(node_ids[0], node_ids[1]) - graph.add_overlap(node_ids[2], node_ids[3]) - return node_ids - - -def _scratch_table_count(graph: SQLGraph) -> int: - """Count leftover ``_tracksdata_ids_*`` scratch tables in a SQLite graph. - - Scratch tables are regular (engine-wide) tables and live in - ``sqlite_master``; we also probe ``sqlite_temp_master`` to flag any - regression that creates a leaky ``TEMPORARY`` scratch table on a - pooled connection. - """ - import sqlalchemy as sa - - total = 0 - with graph._engine.connect() as conn: - for view in ("sqlite_master", "sqlite_temp_master"): - total += conn.execute( - sa.text(f"SELECT COUNT(*) FROM {view} WHERE type='table' AND name LIKE '_tracksdata_ids_%'") - ).scalar() - return total - - -def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: - """Filtering with more ids than SQLite's variable limit must not raise. - - Reproduces the ``OperationalError: too many SQL variables`` failure by - forcing the scratch-table code path via a tiny chunk size. 
``overlaps`` - and ``_get_degree`` use the same chunk size to drive their chunked - ``IN(...)`` reads, so they exercise that path without allocating scratch - tables. - """ - graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) - n_nodes = 40 - node_ids = _build_chain_graph(graph, n_nodes) - - # Force the chunked / scratch-table paths on every call site by shrinking - # the chunk size well below ``n_nodes``. - monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4) - - # ``overlaps`` and the degree helpers chunk via ``_chunked_sa_read`` and - # do not allocate scratch tables, so the count stays at zero. - assert _scratch_table_count(graph) == 0 - in_deg = graph.in_degree(node_ids) - assert _scratch_table_count(graph) == 0 - out_deg = graph.out_degree(node_ids) - assert _scratch_table_count(graph) == 0 - overlaps = graph.overlaps(node_ids) - assert _scratch_table_count(graph) == 0 - - assert sum(in_deg) == n_nodes - 1 - assert sum(out_deg) == n_nodes - 1 - assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]) - - filtered = graph.filter(node_ids=node_ids) - # The filter wraps node_ids in an _SQLIDSet, which must materialize to a - # scratch table given the forced tiny chunk size. - assert filtered._uses_scratch_table() - subgraph = filtered.subgraph() - assert subgraph.num_nodes() == n_nodes - assert subgraph.num_edges() == n_nodes - 1 - - # Once the filter is collected, the scratch table is dropped. - del filtered, subgraph - gc.collect() - assert _scratch_table_count(graph) == 0 - - -def test_sql_from_other_excludes_scratch_tables(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: - """``from_other`` over the SQLite attach-dump path must not leak scratch - tables into the destination DB. 
- - The filtered fast path creates a ``_tracksdata_ids_`` selection - table on the source engine; if the DDL replay does not exclude it the - destination ends up with the internal helper persisted alongside - ``Node`` / ``Edge`` / ``Overlap`` / ``Metadata``. - """ - import sqlalchemy as sa - - monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 1) - - src = SQLGraph("sqlite", str(tmp_path / "src.db")) - node_ids = _build_chain_graph(src, n_nodes=6) - - subgraph = src.filter(node_ids=node_ids).subgraph() - dst_db = tmp_path / "dst.db" - dst = SQLGraph.from_other(subgraph, drivername="sqlite", database=str(dst_db)) - - try: - with dst._engine.connect() as conn: - names = { - row[0] for row in conn.execute(sa.text("SELECT name FROM sqlite_master WHERE type='table'")).fetchall() - } - finally: - dst._engine.dispose() - - assert names == {"Node", "Edge", "Overlap", "Metadata"}, names - - -def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: - """The scratch cutoff must account for how many times ids appear per statement. - - With ``_sql_chunk_size() == 12`` and ``SQLFilter`` using ``occurrences=3``, - a list of 5 ids would compile to ~15 bound variables — above the limit — - even though ``len(node_ids) <= chunk_size``. The helper must still switch - to the scratch-table path in that band. - """ - graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) - n_nodes = 5 - node_ids = _build_chain_graph(graph, n_nodes) - - monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 12) - - filtered = graph.filter(node_ids=node_ids) - # 5 ids fits under chunk_size=12 inline, but with occurrences=3 the - # effective cutoff is 12 // 3 == 4, so scratch must kick in. 
- assert filtered._uses_scratch_table() - subgraph = filtered.subgraph() - assert subgraph.num_nodes() == n_nodes - assert subgraph.num_edges() == n_nodes - 1 - - del filtered, subgraph - gc.collect() - assert _scratch_table_count(graph) == 0 From 3759d9cbd0f02187bfdee1a3f8178950f65ca01d Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 4 Jun 2026 18:25:41 -0700 Subject: [PATCH 25/30] Restore scratch-table machinery and tests from main MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit (db35287) mistakenly deleted upstream/main's _SQLIDSet scratch-table machinery, _create_id_scratch_table, the out_degree/copy bound-variable handling, and the three scratch-table tests — they were diffed against a stale fork main and wrongly treated as PR-added code. Restore them verbatim from main; the struct-attr column-expansion simplification from the previous commit is kept. --- src/tracksdata/graph/_sql_graph.py | 313 +++++++++++++++----- src/tracksdata/graph/_test/test_subgraph.py | 135 +++++++++ 2 files changed, 368 insertions(+), 80 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index dccc70e0..690e5696 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,5 +1,7 @@ import binascii import re +import uuid +import weakref from collections.abc import Callable, Sequence from enum import Enum from pathlib import Path @@ -84,6 +86,78 @@ def _resolve_attr_filter_column( return getattr(table, flat_col) +# Module-level (not methods) so they can be registered with ``weakref.finalize`` +# without holding a bound reference to the owning object, which would prevent +# it from ever being collected. +def _drop_scratch_table(engine: sa.Engine, table: sa.Table) -> None: + """Drop a scratch table, swallowing errors (e.g. 
at interpreter shutdown).""" + try: + table.drop(engine) + except Exception as exc: + LOG.debug("Failed to drop scratch table %s: %s", table.name, exc) + + +class _SQLIDSet: + """A set of ids usable in SQL ``IN`` clauses without overflowing bind limits. + + Small sets compile to inline ``col.in_([...])``; larger sets are materialized + into a per-instance scratch table on ``graph._engine`` and matched via + ``col.in_(SELECT id FROM scratch)``. The scratch table is a regular table + (not ``TEMPORARY``) so it is visible from any pool connection the filter + later uses; the caller drops it via :meth:`close` once the queries that + reference it are no longer needed. + + ``occurrences`` is the maximum number of times the id set will be expanded + in a single compiled statement (e.g. filtering both ``source_id`` and + ``target_id`` of an edge table counts as 2). The scratch-table cutoff is + divided by it so that ``len(ids) * occurrences`` stays safely under the + backend's bound-variable limit. + """ + + def __init__( + self, + graph: "SQLGraph", + ids: Sequence[int], + *, + occurrences: int = 1, + ) -> None: + if hasattr(ids, "tolist"): + ids = ids.tolist() + self._ids: list[int] = list(ids) + # Hold the engine, not the graph, so this set does not participate in + # the graph -> SQLFilter -> _SQLIDSet -> graph reference cycle. + # Otherwise the scratch table would only be dropped after Python's + # cycle GC runs, delaying cleanup in long-running processes. 
+ self._engine = graph._engine + + limit = max(1, graph._sql_chunk_size() // max(1, occurrences)) + if len(self._ids) > limit: + self._scratch: sa.Table | None = graph._create_id_scratch_table(self._ids) + else: + self._scratch = None + + @property + def uses_scratch_table(self) -> bool: + return self._scratch is not None + + def in_clause(self, column: sa.ColumnElement) -> "sa.ColumnElement[bool]": + if self._scratch is None: + return column.in_(self._ids) + return column.in_(sa.select(self._scratch.c.id)) + + def close(self) -> None: + if self._scratch is not None: + _drop_scratch_table(self._engine, self._scratch) + self._scratch = None + + +def _close_id_set(id_set: "_SQLIDSet") -> None: + try: + id_set.close() + except Exception as exc: + LOG.debug("Failed to close _SQLIDSet: %s", exc) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -120,6 +194,17 @@ def _filter_query( class SQLFilter(BaseFilter): + """SQL-backed filter over an :class:`SQLGraph`. + + When ``node_ids`` is larger than the backend's bound-variable budget + (after accounting for how many ``IN (...)`` clauses the list expands + into), the filter materializes the ids into a per-instance scratch + table on ``graph._engine`` and references it via subselects. The + scratch table is dropped when the filter is garbage-collected (via + :func:`weakref.finalize`), so callers don't need to close the filter + explicitly. 
+ """ + def __init__( self, *attr_filters: AttrComparison, @@ -133,6 +218,7 @@ def __init__( self._node_attr_comps, self._edge_attr_comps = split_attr_comps(attr_filters) self._include_targets = include_targets self._include_sources = include_sources + self._id_set: _SQLIDSet | None = None # creating initial query self._node_query: sa.Select = sa.select(self._graph.Node) @@ -140,18 +226,20 @@ def __init__( node_filtered = False if node_ids is not None: - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() - - self._node_query = self._node_query.filter(self._graph.Node.node_id.in_(node_ids)) + # The node_ids list is expanded in up to three IN(...) clauses + # below (once on Node, plus once each on Edge.target_id / + # Edge.source_id unless the corresponding ``include_*`` is set). + # Account for that so the inline/scratch cutoff stays below the + # backend's bound-variable limit for the compiled statement. + occurrences = 1 + int(not self._include_targets) + int(not self._include_sources) + id_set = _SQLIDSet(self._graph, node_ids, occurrences=occurrences) + self._id_set = id_set + + self._node_query = self._node_query.filter(id_set.in_clause(self._graph.Node.node_id)) if not self._include_targets: - self._edge_query = self._edge_query.filter( - self._graph.Edge.target_id.in_(node_ids), - ) + self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.target_id)) if not self._include_sources: - self._edge_query = self._edge_query.filter( - self._graph.Edge.source_id.in_(node_ids), - ) + self._edge_query = self._edge_query.filter(id_set.in_clause(self._graph.Edge.source_id)) node_filtered = True if self._node_attr_comps: @@ -232,6 +320,16 @@ def __init__( self._node_query = sa.union(*nodes_query) + # Drop the scratch table when this filter is collected. Only register a + # finalizer if one was actually allocated, so the common small-set case + # stays free of weakref bookkeeping. 
+ if self._uses_scratch_table(): + weakref.finalize(self, _close_id_set, self._id_set) + + def _uses_scratch_table(self) -> bool: + """Whether the id set backing this filter materialized a scratch table.""" + return self._id_set is not None and self._id_set.uses_scratch_table + @cache_method def node_ids(self) -> list[int]: """ @@ -2139,6 +2237,40 @@ def _chunked_sa_read( chunks.append(data_df) return pl.concat(chunks) + def _create_id_scratch_table(self, ids: Sequence[int]) -> sa.Table: + """Create a uniquely-named helper table holding ``ids`` on ``self._engine``. + + Used to work around SQL bound-variable limits when filtering by large + ``IN (...)`` lists: callers replace ``col.in_(ids)`` with + ``col.in_(sa.select(table.c.id))``. The table is a regular table on + the engine (not ``TEMPORARY``), so it is visible from any session or + connection drawn from the same engine pool — that is what makes it + usable across the multiple ``Session(engine)`` calls inside + :class:`SQLFilter`. + + The caller owns the table's lifetime and must eventually call + ``table.drop(self._engine)`` (or hand the table off to a finalizer + that does so) to remove it. 
+ """ + unique_ids = list({int(v) for v in ids}) + + name = f"_tracksdata_ids_{uuid.uuid4().hex}" + table = sa.Table( + name, + sa.MetaData(), + sa.Column("id", sa.BigInteger, primary_key=True), + ) + table.create(self._engine) + + chunk_size = max(1, self._sql_chunk_size()) + with self._engine.begin() as conn: + for i in range(0, len(unique_ids), chunk_size): + conn.execute( + table.insert(), + [{"id": v} for v in unique_ids[i : i + chunk_size]], + ) + return table + def update_node_attrs( self, *, @@ -2233,13 +2365,22 @@ def _get_degree( with Session(self._engine) as session: return int(session.execute(stmt).scalar()) - stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) - if node_ids is not None: - stmt = stmt.where(edge_key_col.in_(node_ids)) + base_stmt = sa.select(edge_key_col, sa.func.count()).group_by(edge_key_col) + degree: dict[int, int] = {} with Session(self._engine) as session: - # get the number of edges for each using group by and count - degree = dict(session.execute(stmt).all()) + if node_ids is None: + degree.update(session.execute(base_stmt).all()) + else: + # Chunk the IN(...) so the bound-parameter count stays below + # the backend's limit (notably SQLite's + # ``SQLITE_MAX_VARIABLE_NUMBER``). Each chunk's group-by result + # is disjoint, so we can merge them with a simple dict update. + chunk_size = max(1, self._sql_chunk_size()) + for i in range(0, len(node_ids), chunk_size): + chunk = node_ids[i : i + chunk_size] + stmt = base_stmt.where(edge_key_col.in_(chunk)) + degree.update(session.execute(stmt).all()) if node_ids is None: # this is necessary to make sure it's the same order as node_ids @@ -2354,9 +2495,10 @@ def _sqlite_table_dump( reflection path then rebuilds the in-memory state. For filtered copies (``source_node_ids`` not ``None``) the selection - is materialized in a temp table so the row filter joins instead of - using an oversized ``IN (...)`` clause that would hit SQLite's - bound-parameter limit. 
+ is materialized in a per-instance scratch table on the source engine + so the row filter joins instead of using an oversized ``IN (...)`` + clause that would hit SQLite's bound-parameter limit. The scratch + table is dropped in the ``finally`` block before returning. """ dst_database: str = kwargs["database"] dst_path = Path(dst_database) @@ -2372,69 +2514,80 @@ def _sqlite_table_dump( # escape the path safely via single-quote doubling. attach_path = dst_database.replace("'", "''") - with source_root._engine.connect() as conn: - conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst") - try: - # 1. Replicate the source schema by replaying its DDL against - # the attached destination. ``sqlite_master.sql`` is NULL for - # auto-generated objects (e.g. PK indexes), which we skip; - # tables are created before indexes. - ddl_rows = conn.exec_driver_sql( - "SELECT type, sql FROM main.sqlite_master " - "WHERE sql IS NOT NULL AND type IN ('table', 'index') " - "ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END" - ).fetchall() - for _type, ddl in ddl_rows: - qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1) - conn.exec_driver_sql(qualified) - - # 2. Copy rows. The Metadata table is included verbatim — its - # SQL-private schema entries describe the columns we just - # cloned and so are valid for the destination as-is. - if source_node_ids is None: - for table_name in ("Node", "Edge", "Overlap", "Metadata"): - conn.exec_driver_sql(f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"') - else: - node_ids = list(source_node_ids) - if hasattr(node_ids, "tolist"): - node_ids = node_ids.tolist() - # Materialize the selection in a temp table so the row - # filter joins instead of using an oversized IN(...) clause. 
- conn.exec_driver_sql("CREATE TEMP TABLE _td_selected (node_id INTEGER PRIMARY KEY)") - insert_stmt = sa.text("INSERT INTO _td_selected (node_id) VALUES (:node_id)") - chunk_size = max(1, source_root._sql_chunk_size()) - for i in range(0, len(node_ids), chunk_size): - batch = node_ids[i : i + chunk_size] - conn.execute( - insert_stmt, - [{"node_id": int(nid)} for nid in batch], + if source_node_ids is None: + selected: sa.Table | None = None + else: + # Materialize the selection in a per-instance scratch table so the + # row filter joins instead of expanding into an oversized IN(...). + # The table lives on ``source_root._engine`` (visible from the + # ATTACH-ing connection) and is dropped in the outer ``finally``. + # + # We deliberately do not use a ``TEMPORARY`` table here even + # though this function holds a single connection. SQLAlchemy's + # ``Connection.close()`` only returns the underlying DB-API + # connection to the pool, it does not destroy it, so a TEMP table + # would survive into the next consumer of that same pooled SQLite + # connection. A regular table dropped explicitly avoids that. + selected = source_root._create_id_scratch_table(source_node_ids) + + try: + with source_root._engine.connect() as conn: + conn.exec_driver_sql(f"ATTACH DATABASE '{attach_path}' AS _td_dst") + try: + # 1. Replicate the source schema by replaying its DDL against + # the attached destination. ``sqlite_master.sql`` is NULL for + # auto-generated objects (e.g. PK indexes), which we skip; + # tables are created before indexes. ``_tracksdata_ids_*`` + # are internal scratch tables (this call's own ``selected`` + # plus any from live ``SQLFilter``s on the same engine) and + # must not be copied into the persisted destination. 
+ ddl_rows = conn.exec_driver_sql( + "SELECT type, sql FROM main.sqlite_master " + "WHERE sql IS NOT NULL AND type IN ('table', 'index') " + "AND name NOT GLOB '_tracksdata_ids_*' " + "ORDER BY CASE type WHEN 'table' THEN 0 ELSE 1 END" + ).fetchall() + for _type, ddl in ddl_rows: + qualified = cls._SQLITE_DDL_QUALIFIER.sub(r"\g<1>_td_dst.", ddl, count=1) + conn.exec_driver_sql(qualified) + + # 2. Copy rows. The Metadata table is included verbatim — its + # SQL-private schema entries describe the columns we just + # cloned and so are valid for the destination as-is. + if selected is None: + for table_name in ("Node", "Edge", "Overlap", "Metadata"): + conn.exec_driver_sql( + f'INSERT INTO _td_dst."{table_name}" SELECT * FROM main."{table_name}"' + ) + else: + selected_subq = f'SELECT id FROM "{selected.name}"' + + conn.exec_driver_sql( + f'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" WHERE node_id IN ({selected_subq})' ) - - conn.exec_driver_sql( - 'INSERT INTO _td_dst."Node" SELECT * FROM main."Node" ' - "WHERE node_id IN (SELECT node_id FROM _td_selected)" - ) - conn.exec_driver_sql( - 'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" ' - "WHERE source_id IN (SELECT node_id FROM _td_selected) " - "AND target_id IN (SELECT node_id FROM _td_selected)" - ) - conn.exec_driver_sql( - 'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" ' - "WHERE source_id IN (SELECT node_id FROM _td_selected) " - "AND target_id IN (SELECT node_id FROM _td_selected)" - ) - conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"') - conn.exec_driver_sql("DROP TABLE _td_selected") - - conn.commit() - finally: - conn.exec_driver_sql("DETACH DATABASE _td_dst") - - # 3. Open the destination from the now-populated file. The standard - # constructor reflects the schema, restores pickled column types, - # and recomputes ``_max_id_per_time``. 
- return cls(**kwargs) + conn.exec_driver_sql( + f'INSERT INTO _td_dst."Edge" SELECT * FROM main."Edge" ' + f"WHERE source_id IN ({selected_subq}) " + f"AND target_id IN ({selected_subq})" + ) + conn.exec_driver_sql( + f'INSERT INTO _td_dst."Overlap" SELECT * FROM main."Overlap" ' + f"WHERE source_id IN ({selected_subq}) " + f"AND target_id IN ({selected_subq})" + ) + conn.exec_driver_sql('INSERT INTO _td_dst."Metadata" SELECT * FROM main."Metadata"') + + conn.commit() + finally: + conn.exec_driver_sql("DETACH DATABASE _td_dst") + + # 3. Open the destination from the now-populated file. The standard + # constructor reflects the schema, restores pickled column types, + # and recomputes ``_max_id_per_time``. + return cls(**kwargs) + finally: + if selected is not None: + _drop_scratch_table(source_root._engine, selected) def __getstate__(self) -> dict: data_dict = self.__dict__.copy() diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index f38e54c1..6197eca8 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1,3 +1,5 @@ +import gc +import itertools import re from collections.abc import Callable from contextlib import contextmanager @@ -1337,3 +1339,136 @@ def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None: ) ) assert edge_list == expected_edge_list + + +def _build_chain_graph(graph: SQLGraph, n_nodes: int) -> list[int]: + node_ids: list[int] = [] + for t in range(n_nodes): + node_ids.append(graph.add_node({DEFAULT_ATTR_KEYS.T: t})) + for src, tgt in itertools.pairwise(node_ids): + graph.add_edge(src, tgt, {}) + graph.add_overlap(node_ids[0], node_ids[1]) + graph.add_overlap(node_ids[2], node_ids[3]) + return node_ids + + +def _scratch_table_count(graph: SQLGraph) -> int: + """Count leftover ``_tracksdata_ids_*`` scratch tables in a SQLite graph. 
+ + Scratch tables are regular (engine-wide) tables and live in + ``sqlite_master``; we also probe ``sqlite_temp_master`` to flag any + regression that creates a leaky ``TEMPORARY`` scratch table on a + pooled connection. + """ + import sqlalchemy as sa + + total = 0 + with graph._engine.connect() as conn: + for view in ("sqlite_master", "sqlite_temp_master"): + total += conn.execute( + sa.text(f"SELECT COUNT(*) FROM {view} WHERE type='table' AND name LIKE '_tracksdata_ids_%'") + ).scalar() + return total + + +def test_sql_graph_filter_large_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """Filtering with more ids than SQLite's variable limit must not raise. + + Reproduces the ``OperationalError: too many SQL variables`` failure by + forcing the scratch-table code path via a tiny chunk size. ``overlaps`` + and ``_get_degree`` use the same chunk size to drive their chunked + ``IN(...)`` reads, so they exercise that path without allocating scratch + tables. + """ + graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) + n_nodes = 40 + node_ids = _build_chain_graph(graph, n_nodes) + + # Force the chunked / scratch-table paths on every call site by shrinking + # the chunk size well below ``n_nodes``. + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 4) + + # ``overlaps`` and the degree helpers chunk via ``_chunked_sa_read`` and + # do not allocate scratch tables, so the count stays at zero. 
+ assert _scratch_table_count(graph) == 0 + in_deg = graph.in_degree(node_ids) + assert _scratch_table_count(graph) == 0 + out_deg = graph.out_degree(node_ids) + assert _scratch_table_count(graph) == 0 + overlaps = graph.overlaps(node_ids) + assert _scratch_table_count(graph) == 0 + + assert sum(in_deg) == n_nodes - 1 + assert sum(out_deg) == n_nodes - 1 + assert sorted(map(tuple, overlaps)) == sorted([(node_ids[0], node_ids[1]), (node_ids[2], node_ids[3])]) + + filtered = graph.filter(node_ids=node_ids) + # The filter wraps node_ids in an _SQLIDSet, which must materialize to a + # scratch table given the forced tiny chunk size. + assert filtered._uses_scratch_table() + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + + # Once the filter is collected, the scratch table is dropped. + del filtered, subgraph + gc.collect() + assert _scratch_table_count(graph) == 0 + + +def test_sql_from_other_excludes_scratch_tables(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """``from_other`` over the SQLite attach-dump path must not leak scratch + tables into the destination DB. + + The filtered fast path creates a ``_tracksdata_ids_`` selection + table on the source engine; if the DDL replay does not exclude it the + destination ends up with the internal helper persisted alongside + ``Node`` / ``Edge`` / ``Overlap`` / ``Metadata``. 
+ """ + import sqlalchemy as sa + + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 1) + + src = SQLGraph("sqlite", str(tmp_path / "src.db")) + node_ids = _build_chain_graph(src, n_nodes=6) + + subgraph = src.filter(node_ids=node_ids).subgraph() + dst_db = tmp_path / "dst.db" + dst = SQLGraph.from_other(subgraph, drivername="sqlite", database=str(dst_db)) + + try: + with dst._engine.connect() as conn: + names = { + row[0] for row in conn.execute(sa.text("SELECT name FROM sqlite_master WHERE type='table'")).fetchall() + } + finally: + dst._engine.dispose() + + assert names == {"Node", "Edge", "Overlap", "Metadata"}, names + + +def test_sql_graph_filter_borderline_node_ids(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """The scratch cutoff must account for how many times ids appear per statement. + + With ``_sql_chunk_size() == 12`` and ``SQLFilter`` using ``occurrences=3``, + a list of 5 ids would compile to ~15 bound variables — above the limit — + even though ``len(node_ids) <= chunk_size``. The helper must still switch + to the scratch-table path in that band. + """ + graph = SQLGraph("sqlite", str(tmp_path / "scratch.db")) + n_nodes = 5 + node_ids = _build_chain_graph(graph, n_nodes) + + monkeypatch.setattr(SQLGraph, "_sql_chunk_size", lambda self: 12) + + filtered = graph.filter(node_ids=node_ids) + # 5 ids fits under chunk_size=12 inline, but with occurrences=3 the + # effective cutoff is 12 // 3 == 4, so scratch must kick in. 
+ assert filtered._uses_scratch_table() + subgraph = filtered.subgraph() + assert subgraph.num_nodes() == n_nodes + assert subgraph.num_edges() == n_nodes - 1 + + del filtered, subgraph + gc.collect() + assert _scratch_table_count(graph) == 0 From 0d76262b0bf6c681c7648474373244a803258174 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Mon, 8 Jun 2026 01:41:35 -0700 Subject: [PATCH 26/30] ignored the devcontaienr --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b652a88f..2695912e 100644 --- a/.gitignore +++ b/.gitignore @@ -198,3 +198,4 @@ src/tracksdata/__about__.py # Claude .claude/ +.devcontainer/ From 87707d03cb28f8f07797c043c69934fa3fde241f Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Mon, 8 Jun 2026 02:14:23 -0700 Subject: [PATCH 27/30] bugfix --- src/tracksdata/graph/_sql_graph.py | 69 +++++++++---------- .../graph/_test/test_graph_backends.py | 10 +++ 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 690e5696..55d00ad6 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -377,24 +377,28 @@ def node_attrs( attr_keys=attr_keys, ) - with Session(self._graph._engine) as session: - nodes_attrs = pl.read_database( - self._graph._raw_query(query), - connection=session.connection(), - schema_overrides=self._graph._polars_schema_override(self._graph.Node), - ) + nodes_attrs = self._read_attr_dataframe(query, self._graph.Node) if attr_keys is not None: + attr_keys = list(dict.fromkeys(attr_keys)) nodes_attrs = nodes_attrs.select(attr_keys) - nodes_attrs = unpickle_bytes_columns(nodes_attrs) - nodes_attrs = self._graph._cast_columns(self._graph.Node, nodes_attrs) - if unpack: nodes_attrs = unpack_array_attrs(nodes_attrs) return nodes_attrs + def _read_attr_dataframe(self, query: sa.Select, table: type[DeclarativeBase]) -> pl.DataFrame: + with Session(self._graph._engine) as session: + 
df = pl.read_database( + self._graph._raw_query(query), + connection=session.connection(), + schema_overrides=self._graph._polars_schema_override(table), + ) + + df = unpickle_bytes_columns(df) + return self._graph._cast_columns(table, df) + def _query_from_attr_keys( self, query: sa.Select, @@ -439,15 +443,7 @@ def edge_attrs(self, attr_keys: list[str] | None = None, unpack: bool = False) - ], ) - with Session(self._graph._engine) as session: - edges_df = pl.read_database( - self._graph._raw_query(query), - connection=session.connection(), - schema_overrides=self._graph._polars_schema_override(self._graph.Edge), - ) - - edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._graph._cast_columns(self._graph.Edge, edges_df) + edges_df = self._read_attr_dataframe(query, self._graph.Edge) if unpack: edges_df = unpack_array_attrs(edges_df) @@ -492,26 +488,23 @@ def subgraph( ], ) - with Session(self._graph._engine) as session: - node_query = session.execute(node_query) - edge_query = session.execute(edge_query) - - node_map_to_root = {} - node_map_from_root = {} - rx_graph = rx.PyDiGraph() - - for row in node_query.mappings().all(): - data = dict(row) - root_node_id = data.pop(DEFAULT_ATTR_KEYS.NODE_ID) - node_id = rx_graph.add_node(data) - node_map_to_root[node_id] = root_node_id - node_map_from_root[root_node_id] = node_id - - for row in edge_query.mappings().all(): - data = dict(row) - source_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_SOURCE)] - target_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_TARGET)] - rx_graph.add_edge(source_id, target_id, data) + nodes_df = self._read_attr_dataframe(node_query, self._graph.Node) + edges_df = self._read_attr_dataframe(edge_query, self._graph.Edge) + + node_map_to_root = {} + node_map_from_root = {} + rx_graph = rx.PyDiGraph() + + for data in nodes_df.iter_rows(named=True): + root_node_id = data.pop(DEFAULT_ATTR_KEYS.NODE_ID) + node_id = rx_graph.add_node(data) + node_map_to_root[node_id] = 
root_node_id + node_map_from_root[root_node_id] = node_id + + for data in edges_df.iter_rows(named=True): + source_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_SOURCE)] + target_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_TARGET)] + rx_graph.add_edge(source_id, target_id, data) graph = GraphView( rx_graph=rx_graph, diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index dfc6350f..69e9bc09 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -239,6 +239,16 @@ def test_filter_nodes_by_struct_field(graph_backend: BaseGraph) -> None: name_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("name") == "B").node_ids() assert set(name_nodes) == {node_b} + measurements = graph_backend.filter(node_ids=[node_a, node_c]).node_attrs(attr_keys=["measurements"]) + assert measurements.schema["measurements"] == pl.Struct({"score": pl.Int64, "name": pl.String}) + assert {m["name"] for m in measurements["measurements"].to_list()} == {"A", "C"} + + subgraph = graph_backend.filter(NodeAttr("measurements").struct.field("score") == 1).subgraph( + node_attr_keys=["measurements"] + ) + subgraph_measurements = subgraph.node_attrs(attr_keys=["measurements"]) + assert {m["name"] for m in subgraph_measurements["measurements"].to_list()} == {"A", "C"} + def test_time_points(graph_backend: BaseGraph) -> None: """Test retrieving time points.""" From 9ba99e1e1f931cc95a4bad0ca59adbdd137e4796 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 10 Jun 2026 07:01:43 -0700 Subject: [PATCH 28/30] Store Mask as a struct attribute instead of pickled pl.Object Masks were registered as pl.Object and round-tripped through pickle, leaving the bbox locked inside an opaque blob. 
They are now stored as pl.Struct({min_(z)yx, max_(z)yx: Int64, data: Binary}) so bbox fields are natively filterable via NodeAttr("mask").struct.field(...), while the binary mask stays blosc2-compressed in the data field. - Mask.struct_dtype() / to_struct() / from_struct() conversion API, plus as_mask() to coerce struct dicts and legacy Mask objects alike - RegionPropsNodes and MaskDiskAttrs register the struct dtype and write struct values - consumers (GraphArrayView, IoUEdgeAttr, MaskMatching, ctc metrics, compute_overlaps, to_geff) materialize masks via as_mask(), so legacy pl.Object mask attributes keep working - ctc metrics skip the pickle-to-bytes multiprocessing shim for struct mask columns, which are Arrow-native Co-Authored-By: Claude Fable 5 --- src/tracksdata/array/_graph_array.py | 7 +- src/tracksdata/edges/_iou_edges.py | 9 +- src/tracksdata/edges/_test/test_iou_edges.py | 5 +- .../functional/_test/test_napari.py | 4 +- src/tracksdata/graph/_base_graph.py | 8 +- .../graph/_test/test_graph_backends.py | 8 +- src/tracksdata/metrics/_ctc_metrics.py | 4 +- src/tracksdata/metrics/_matching.py | 13 +- src/tracksdata/nodes/__init__.py | 4 +- src/tracksdata/nodes/_mask.py | 148 ++++++++++++++++-- src/tracksdata/nodes/_regionprops.py | 5 +- src/tracksdata/nodes/_test/test_mask.py | 86 +++++++++- .../nodes/_test/test_regionprops.py | 8 +- 13 files changed, 265 insertions(+), 44 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 6e88828d..a8eeac8f 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -12,7 +12,7 @@ from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype if TYPE_CHECKING: - from tracksdata.nodes._mask import Mask + pass def _validate_shape( @@ -346,14 +346,15 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda np.ndarray The filled buffer. 
""" + from tracksdata.nodes._mask import as_mask + subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)] df = subgraph.node_attrs( attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK], ) for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): - mask: Mask - mask.paint_buffer(buffer, value, offset=self._offset) + as_mask(mask).paint_buffer(buffer, value, offset=self._offset) def _offset_as_array(self, ndim: int) -> np.ndarray: """Normalize `offset` to a vector for each spatial axis.""" diff --git a/src/tracksdata/edges/_iou_edges.py b/src/tracksdata/edges/_iou_edges.py index c70a3492..ce591b13 100644 --- a/src/tracksdata/edges/_iou_edges.py +++ b/src/tracksdata/edges/_iou_edges.py @@ -1,6 +1,11 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs -from tracksdata.nodes._mask import Mask +from tracksdata.nodes._mask import Mask, as_mask + + +def _mask_iou(source_mask: "Mask | dict", target_mask: "Mask | dict") -> float: + """IoU between two mask attribute values (struct dicts or `Mask` instances).""" + return as_mask(source_mask).iou(as_mask(target_mask)) class IoUEdgeAttr(GenericFuncEdgeAttrs): @@ -22,7 +27,7 @@ def __init__( mask_key: str = DEFAULT_ATTR_KEYS.MASK, ): super().__init__( - func=Mask.iou, + func=_mask_iou, attr_keys=mask_key, output_key=output_key, ) diff --git a/src/tracksdata/edges/_test/test_iou_edges.py b/src/tracksdata/edges/_test/test_iou_edges.py index c9ecdf6d..7d5474b2 100644 --- a/src/tracksdata/edges/_test/test_iou_edges.py +++ b/src/tracksdata/edges/_test/test_iou_edges.py @@ -4,6 +4,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges import IoUEdgeAttr +from tracksdata.edges._iou_edges import _mask_iou from tracksdata.graph import RustWorkXGraph from tracksdata.nodes import Mask from tracksdata.options import get_options, options_context @@ -15,7 +16,7 @@ def test_iou_edges_init_default() -> None: assert 
operator.output_key == "iou_score" assert operator.attr_keys == DEFAULT_ATTR_KEYS.MASK - assert operator.func == Mask.iou + assert operator.func == _mask_iou def test_iou_edges_init_custom() -> None: @@ -24,7 +25,7 @@ def test_iou_edges_init_custom() -> None: assert operator.output_key == "custom_iou" assert operator.attr_keys == "custom_mask" - assert operator.func == Mask.iou + assert operator.func == _mask_iou @pytest.mark.parametrize("n_workers", [1, 2]) diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index 712cf53d..cdc996c5 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -5,7 +5,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.functional import to_napari_format from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import MaskDiskAttrs +from tracksdata.nodes import MaskDiskAttrs, as_mask @pytest.mark.parametrize("metadata_shape", [True, False]) @@ -47,7 +47,7 @@ def test_napari_conversion(metadata_shape: bool) -> None: graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 6)) masks = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK] graph.update_node_attrs( - attrs={DEFAULT_ATTR_KEYS.BBOX: [mask.bbox for mask in masks]}, + attrs={DEFAULT_ATTR_KEYS.BBOX: [as_mask(mask).bbox for mask in masks]}, node_ids=graph.node_ids(), ) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 6067b71c..7ca40b71 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1326,11 +1326,13 @@ def compute_overlaps(self, iou_threshold: float = 0.0) -> None: raise ValueError("iou_threshold must be between 0.0 and 1.0") def _estimate_overlaps(t: int) -> list[list[int, 2]]: + from tracksdata.nodes._mask import as_mask + node_attrs = self.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t).node_attrs( 
attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MASK], ) node_ids = node_attrs[DEFAULT_ATTR_KEYS.NODE_ID].to_list() - masks = node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list() + masks = [as_mask(m) for m in node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list()] overlaps = [] for i in range(len(masks)): mask_i = masks[i] @@ -1887,8 +1889,10 @@ def to_geff( } if DEFAULT_ATTR_KEYS.MASK in node_attrs.columns: + from tracksdata.nodes._mask import as_mask + node_dict[DEFAULT_ATTR_KEYS.MASK] = construct_var_len_props( - [mask.mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]] + [as_mask(mask).mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]] ) edge_dict = {k: {"values": column_to_numpy(v), "missing": None} for k, v in edge_attrs.to_dict().items()} diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 69e9bc09..09db86d4 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -16,7 +16,7 @@ from tracksdata.graph import BaseGraph, IndexedRXGraph, RustWorkXGraph, SQLGraph from tracksdata.io._numpy_array import from_array from tracksdata.nodes import RegionPropsNodes -from tracksdata.nodes._mask import Mask +from tracksdata.nodes._mask import Mask, as_mask def test_already_existing_keys(graph_backend: BaseGraph) -> None: @@ -1621,10 +1621,8 @@ def build_node_map(graph: BaseGraph) -> dict[tuple[int, tuple[int, ...]], dict[s assert target_row["y"] == pytest.approx(source_row["y"]) assert target_row["x"] == pytest.approx(source_row["x"]) - source_mask = source_row[DEFAULT_ATTR_KEYS.MASK] - target_mask = target_row[DEFAULT_ATTR_KEYS.MASK] - assert isinstance(source_mask, Mask) - assert isinstance(target_mask, Mask) + source_mask = as_mask(source_row[DEFAULT_ATTR_KEYS.MASK]) + target_mask = as_mask(target_row[DEFAULT_ATTR_KEYS.MASK]) np.testing.assert_array_equal(source_mask.mask, target_mask.mask) 
np.testing.assert_array_equal(source_mask.bbox, target_mask.bbox) diff --git a/src/tracksdata/metrics/_ctc_metrics.py b/src/tracksdata/metrics/_ctc_metrics.py index 5bd3e84f..2fe667ef 100644 --- a/src/tracksdata/metrics/_ctc_metrics.py +++ b/src/tracksdata/metrics/_ctc_metrics.py @@ -166,8 +166,8 @@ def _matching_data( ]: attr_keys = [DEFAULT_ATTR_KEYS.T, tracklet_id_key, *required_attrs] nodes_df = graph.node_attrs(attr_keys=attr_keys) - if use_mask_serialization: - # required by multiprocessing + if use_mask_serialization and nodes_df[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Object: + # required by multiprocessing; struct mask columns are Arrow-native and need no pickling nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK) labels = {} diff --git a/src/tracksdata/metrics/_matching.py b/src/tracksdata/metrics/_matching.py index 1dca1aed..2f0962d2 100644 --- a/src/tracksdata/metrics/_matching.py +++ b/src/tracksdata/metrics/_matching.py @@ -144,6 +144,7 @@ def compute_weights( tuple[list[int], list[int], list[int], list[int], list[float]] Matching data: mapped_ref, mapped_comp, rows, cols, weights (IoU values). 
""" + from tracksdata.nodes._mask import as_mask from tracksdata.utils._dtypes import column_from_bytes # Handle serialized masks if needed @@ -151,18 +152,18 @@ def compute_weights( ref_group = column_from_bytes(ref_group, DEFAULT_ATTR_KEYS.MASK) comp_group = column_from_bytes(comp_group, DEFAULT_ATTR_KEYS.MASK) + # Materialize masks once, struct values decompress on conversion + ref_masks = [as_mask(m) for m in ref_group[DEFAULT_ATTR_KEYS.MASK]] + comp_masks = [as_mask(m) for m in comp_group[DEFAULT_ATTR_KEYS.MASK]] + mapped_ref = [] mapped_comp = [] rows = [] cols = [] weights = [] - for i, (ref_id, ref_mask) in enumerate( - zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True) - ): - for j, (comp_id, comp_mask) in enumerate( - zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True) - ): + for i, (ref_id, ref_mask) in enumerate(zip(ref_group[reference_graph_key], ref_masks, strict=True)): + for j, (comp_id, comp_mask) in enumerate(zip(comp_group[input_graph_key], comp_masks, strict=True)): # Intersection over reference is used to select the matches inter = ref_mask.intersection(comp_mask) ctc_score = inter / ref_mask.size diff --git a/src/tracksdata/nodes/__init__.py b/src/tracksdata/nodes/__init__.py index b031d4bc..1fb6f4df 100644 --- a/src/tracksdata/nodes/__init__.py +++ b/src/tracksdata/nodes/__init__.py @@ -1,8 +1,8 @@ """Node operators for creating nodes and their respective attributes (e.g. 
masks) in a graph.""" from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs -from tracksdata.nodes._mask import Mask, MaskDiskAttrs +from tracksdata.nodes._mask import Mask, MaskDiskAttrs, as_mask from tracksdata.nodes._random import RandomNodes from tracksdata.nodes._regionprops import RegionPropsNodes -__all__ = ["GenericFuncNodeAttrs", "Mask", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes"] +__all__ = ["GenericFuncNodeAttrs", "Mask", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes", "as_mask"] diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 22a7619f..2abf0662 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -19,6 +19,28 @@ from tracksdata.graph._base_graph import BaseGraph +def _pack_mask_array(mask: NDArray) -> bytes: + """Compress a mask array into a blosc2 cframe.""" + mask = np.ascontiguousarray(mask) + prev_nthreads = blosc2.set_nthreads(1) + # Bypass blosc2 printing overhead by directly creating a schunk and converting it to cframe, + # instead of using blosc2.pack_tensor + schunk = blosc2.SChunk(data=mask) + dtype = mask.dtype.descr if mask.dtype.kind == "V" else mask.dtype.str + schunk.vlmeta["__pack_tensor__"] = ("numpy", mask.shape, dtype) + cframe = schunk.to_cframe() + blosc2.set_nthreads(prev_nthreads) + return cframe + + +def _unpack_mask_array(data: bytes) -> NDArray: + """Decompress a blosc2 cframe into a mask array.""" + prev_nthreads = blosc2.set_nthreads(1) + mask = blosc2.unpack_tensor(data) + blosc2.set_nthreads(prev_nthreads) + return mask + + @lru_cache(maxsize=5) def _nd_sphere( radius: int, @@ -66,20 +88,11 @@ def __init__( def __getstate__(self) -> dict: data_dict = self.__dict__.copy() - prev_nthreads = blosc2.set_nthreads(1) - # Bypass blosc2 printing overhead by directly creating a schunk and converting it to cframe, - # instead of using blosc2.pack_tensor - schunk = blosc2.SChunk(data=self._mask) - dtype = self._mask.dtype.descr if 
self._mask.dtype.kind == "V" else self._mask.dtype.str - schunk.vlmeta["__pack_tensor__"] = ("numpy", self._mask.shape, dtype) - data_dict["_mask"] = schunk.to_cframe() - blosc2.set_nthreads(prev_nthreads) + data_dict["_mask"] = _pack_mask_array(self._mask) return data_dict def __setstate__(self, state: dict) -> None: - prev_nthreads = blosc2.set_nthreads(1) - state["_mask"] = blosc2.unpack_tensor(state["_mask"]) - blosc2.set_nthreads(prev_nthreads) + state["_mask"] = _unpack_mask_array(state["_mask"]) self.__dict__.update(state) @property @@ -497,6 +510,115 @@ def __eq__(self, other: Any) -> bool: return False return np.array_equal(self.bbox, other.bbox) and np.array_equal(self.mask, other.mask) + MASK_DATA_FIELD = "data" + + @staticmethod + def bbox_struct_fields(ndim: int) -> list[str]: + """ + Names of the bounding box fields of the mask struct attribute. + + Fields follow the ``bbox`` layout (start indices then end indices), + named after the (z), y, x axis convention, e.g. for 2D: + ``["min_y", "min_x", "max_y", "max_x"]``. + + Parameters + ---------- + ndim : int + The number of spatial dimensions (2 or 3). + + Returns + ------- + list[str] + The bounding box field names. + """ + if ndim < 1 or ndim > 3: + raise ValueError(f"Mask struct attributes are only supported for 1D to 3D masks, got ndim={ndim}") + axes = "zyx"[-ndim:] + return [f"min_{a}" for a in axes] + [f"max_{a}" for a in axes] + + @staticmethod + def struct_dtype(ndim: int) -> pl.Struct: + """ + Polars struct dtype used to store a `Mask` as a struct attribute. + + Bounding box coordinates are scalar integer fields so backends can + filter on them natively (e.g. `NodeAttr("mask").struct.field("min_y") > 5`), + while the binary mask is stored blosc2-compressed in the ``data`` field. + + Parameters + ---------- + ndim : int + The number of spatial dimensions (2 or 3). + + Returns + ------- + pl.Struct + The struct dtype, e.g. 
for 2D: + `pl.Struct({"min_y": Int64, "min_x": Int64, "max_y": Int64, "max_x": Int64, "data": Binary})`. + """ + fields = dict.fromkeys(Mask.bbox_struct_fields(ndim), pl.Int64) + fields[Mask.MASK_DATA_FIELD] = pl.Binary + return pl.Struct(fields) + + def to_struct(self) -> dict[str, Any]: + """ + Convert the mask to a dict matching [struct_dtype][tracksdata.nodes.Mask.struct_dtype]. + + Returns + ------- + dict[str, Any] + Scalar bounding box fields plus the blosc2-compressed mask under ``"data"``. + """ + fields = self.bbox_struct_fields(self._mask.ndim) + value: dict[str, Any] = {f: int(b) for f, b in zip(fields, self._bbox, strict=True)} + value[self.MASK_DATA_FIELD] = _pack_mask_array(self._mask) + return value + + @classmethod + def from_struct(cls, value: dict[str, Any]) -> "Mask": + """ + Reconstruct a mask from a struct attribute value. + + Parameters + ---------- + value : dict[str, Any] + A dict as produced by [to_struct][tracksdata.nodes.Mask.to_struct]. + + Returns + ------- + Mask + The reconstructed mask. + """ + mask = _unpack_mask_array(value[cls.MASK_DATA_FIELD]) + fields = cls.bbox_struct_fields(mask.ndim) + bbox = np.asarray([value[f] for f in fields], dtype=np.int64) + return cls(mask, bbox) + + +def as_mask(value: "Mask | dict[str, Any]") -> Mask: + """ + Coerce a mask attribute value to a `Mask` instance. + + Accepts both the struct representation (dicts, as returned by graph + backends for struct mask attributes) and `Mask` instances + (legacy `pl.Object` mask attributes). + + Parameters + ---------- + value : Mask | dict[str, Any] + The mask attribute value. + + Returns + ------- + Mask + The coerced mask. 
+ """ + if isinstance(value, Mask): + return value + if isinstance(value, dict): + return Mask.from_struct(value) + raise TypeError(f"Cannot interpret {type(value)} as a Mask.") + class MaskDiskAttrs(GenericFuncNodeAttrs): """ @@ -540,7 +662,7 @@ def __init__( center=np.asarray(list(kwargs.values())), radius=radius, image_shape=image_shape, - ), + ).to_struct(), output_key=output_key, attr_keys=attr_keys, batch_size=0, @@ -551,4 +673,4 @@ def _init_node_attrs(self, graph: "BaseGraph") -> None: Validate that the output key exists in the graph. """ if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, pl.Object) + graph.add_node_attr_key(self.output_key, Mask.struct_dtype(len(self._image_shape))) diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index 11ad7261..c53e9ac0 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -139,6 +139,9 @@ def _init_node_attrs(self, graph: BaseGraph, node_attrs: dict[str, Any]) -> None elif np.isscalar(value): dtype = numpy_char_code_to_dtype(value.dtype) if hasattr(value, "dtype") else type(value) graph.add_node_attr_key(key, dtype) + elif isinstance(value, dict): + # struct-valued attributes, e.g. 
masks stored as `Mask.to_struct()` + graph.add_node_attr_key(key, pl.Series([value]).dtype) elif type(value).__module__ != "builtins": graph.add_node_attr_key(key, pl.Object) else: @@ -306,7 +309,7 @@ def _nodes_per_time( else: attrs[prop] = getattr(obj, prop) - attrs[DEFAULT_ATTR_KEYS.MASK] = Mask(obj.image, obj.bbox) + attrs[DEFAULT_ATTR_KEYS.MASK] = Mask(obj.image, obj.bbox).to_struct() attrs[DEFAULT_ATTR_KEYS.BBOX] = np.asarray(obj.bbox, dtype=int) attrs[DEFAULT_ATTR_KEYS.T] = t diff --git a/src/tracksdata/nodes/_test/test_mask.py b/src/tracksdata/nodes/_test/test_mask.py index 89f9637f..dc017e25 100644 --- a/src/tracksdata/nodes/_test/test_mask.py +++ b/src/tracksdata/nodes/_test/test_mask.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from tracksdata.nodes._mask import Mask, _nd_sphere +from tracksdata.nodes._mask import Mask, _nd_sphere, as_mask def test_mask_init() -> None: @@ -601,3 +601,87 @@ def test_mask_move() -> None: mask.move(offset=np.asarray([-3, 2]), image_shape=(7, 7)) np.testing.assert_array_equal(mask.bbox, [2, 4, 3, 5]) np.testing.assert_array_equal(mask.mask, point) + + +def test_mask_struct_dtype() -> None: + import polars as pl + + dtype_2d = Mask.struct_dtype(2) + assert dtype_2d == pl.Struct( + { + "min_y": pl.Int64, + "min_x": pl.Int64, + "max_y": pl.Int64, + "max_x": pl.Int64, + "data": pl.Binary, + } + ) + + dtype_3d = Mask.struct_dtype(3) + assert [f.name for f in dtype_3d.fields] == [ + "min_z", + "min_y", + "min_x", + "max_z", + "max_y", + "max_x", + "data", + ] + + with pytest.raises(ValueError): + Mask.struct_dtype(4) + + +@pytest.mark.parametrize("ndim", [2, 3]) +def test_mask_struct_roundtrip(ndim: int) -> None: + rng = np.random.default_rng(0) + shape = (3, 4, 5)[:ndim] + mask_data = rng.uniform(size=shape) > 0.5 + bbox = np.concatenate([np.arange(1, ndim + 1), np.arange(1, ndim + 1) + shape]) + + mask = Mask(mask_data, bbox=bbox) + value = mask.to_struct() + + assert isinstance(value, dict) + assert 
isinstance(value["data"], bytes) + assert value["min_y" if ndim == 2 else "min_z"] == 1 + + restored = Mask.from_struct(value) + assert restored == mask + + +def test_as_mask() -> None: + mask = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([0, 0, 2, 2])) + + assert as_mask(mask) is mask + assert as_mask(mask.to_struct()) == mask + + with pytest.raises(TypeError): + as_mask("not a mask") + + +def test_mask_struct_attr_in_graph(graph_backend) -> None: + """Masks stored as struct attributes round-trip and are filterable by bbox fields.""" + + from tracksdata.attrs import NodeAttr + from tracksdata.constants import DEFAULT_ATTR_KEYS + + graph = graph_backend + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, Mask.struct_dtype(2)) + + mask_a = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([0, 0, 2, 2])) + mask_b = Mask(np.ones((2, 3), dtype=bool), bbox=np.array([5, 6, 7, 9])) + + node_a = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask_a.to_struct()}) + node_b = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask_b.to_struct()}) + + df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MASK]) + assert df[DEFAULT_ATTR_KEYS.MASK].dtype == Mask.struct_dtype(2) + + restored = {n: as_mask(v) for n, v in zip(df[DEFAULT_ATTR_KEYS.NODE_ID], df[DEFAULT_ATTR_KEYS.MASK], strict=True)} + assert restored[node_a] == mask_a + assert restored[node_b] == mask_b + + # filtering on a bbox field of the mask struct + filtered = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.MASK).struct.field("min_y") > 2).node_ids() + assert filtered == [node_b] diff --git a/src/tracksdata/nodes/_test/test_regionprops.py b/src/tracksdata/nodes/_test/test_regionprops.py index 567c62e0..9e8a3d5b 100644 --- a/src/tracksdata/nodes/_test/test_regionprops.py +++ b/src/tracksdata/nodes/_test/test_regionprops.py @@ -4,7 +4,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph import RustWorkXGraph -from tracksdata.nodes import Mask, 
RegionPropsNodes +from tracksdata.nodes import Mask, RegionPropsNodes, as_mask from tracksdata.options import get_options, options_context @@ -278,12 +278,14 @@ def test_regionprops_mask_creation() -> None: assert "shape" in graph.metadata assert graph.metadata["shape"] == labels.shape - # Check that masks were created + # Check that masks were created as struct attributes nodes_df = graph.node_attrs() masks = nodes_df[DEFAULT_ATTR_KEYS.MASK] + assert masks.dtype == Mask.struct_dtype(labels.ndim - 1) - # All masks should be Mask objects + # All masks should convert back to Mask objects for mask in masks: + mask = as_mask(mask) assert isinstance(mask, Mask) assert mask._mask is not None assert mask._bbox is not None From 7d02b80ef5908281f96b0dca37954a97ac4849a6 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 16 Jun 2026 18:02:10 -0700 Subject: [PATCH 29/30] Store binary attributes as raw BLOB in SQL instead of pickling The Mask struct's `data` leaf (blosc2-compressed bytes) was stored through a SQLAlchemy PickleType column, wrapping the already-compressed bytes in a second pickle layer on every write and unpickling on every read. Map `pl.Binary` to `sa.LargeBinary` so binary bytes are stored as a raw BLOB. Pickling is now detected from the actual SQL column type (PickleType only, not LargeBinary): - `_is_pickled_sql_type` returns True only for PickleType columns. - `unpickle_bytes_columns` takes the explicit set of pickled physical columns so raw-binary columns (e.g. the Mask `data` leaf) are left untouched. - `_restore_pickled_column_types` re-tags reflected LargeBinary columns as PickleType except genuine raw-binary columns (schema dtype pl.Binary), which reflection cannot otherwise distinguish from pickled blobs. Reordered `_define_schema` so the attribute schemas are available when this runs. 
Adds a regression test asserting the Mask `data` leaf is a raw LargeBinary column (not PickleType) before and after reload, and that the mask round-trips and struct-field filtering still works. Co-Authored-By: Claude Fable 5 --- src/tracksdata/graph/_sql_graph.py | 70 +++++++++++++++---- .../graph/_test/test_graph_backends.py | 37 ++++++++++ src/tracksdata/utils/_dataframe.py | 26 +++++-- src/tracksdata/utils/_dtypes.py | 3 + src/tracksdata/utils/_test/test_dataframe.py | 2 +- 5 files changed, 119 insertions(+), 19 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 2c703c15..70fee933 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -422,7 +422,7 @@ def _read_attr_dataframe(self, query: sa.Select, table: type[DeclarativeBase]) - schema_overrides=self._graph._polars_schema_override(table), ) - df = unpickle_bytes_columns(df) + df = unpickle_bytes_columns(df, self._graph._pickled_physical_columns(table)) return self._graph._cast_columns(table, df) def _query_from_attr_keys( @@ -656,6 +656,11 @@ def __init__( self._engine_kwargs = engine_kwargs if engine_kwargs is not None else {} self._engine = sa.create_engine(self._url, **self._engine_kwargs) + # Initialized before `_define_schema`, which (when reloading an existing + # database) reads the attribute schemas to restore pickle column types. 
+ self._node_attr_schemas_cache: dict | None = None + self._edge_attr_schemas_cache: dict | None = None + # Create unique classes for this instance self._define_schema(overwrite=overwrite) @@ -666,8 +671,6 @@ def __init__( self._max_id_per_time = {} self._update_max_id_per_time() - self._node_attr_schemas_cache: dict | None = None - self._edge_attr_schemas_cache: dict | None = None def supports_custom_indices(self) -> bool: return True @@ -686,8 +689,6 @@ class Base(DeclarativeBase): pass if len(metadata.tables) > 0 and not overwrite: - for table in metadata.tables.values(): - self._restore_pickled_column_types(table) for table_name, table in metadata.tables.items(): cls = type( table_name, @@ -699,6 +700,10 @@ class Base(DeclarativeBase): ) setattr(self, table_name, cls) self.Base = Base + # Restore pickle column types only after the ORM classes exist, so + # `_pickled_physical_columns` can read the stored attribute schemas. + for table_name in metadata.tables: + self._restore_pickled_column_types(getattr(self, table_name)) return class Node(Base): @@ -794,7 +799,37 @@ def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[st @staticmethod def _is_pickled_sql_type(column_type: TypeEngine) -> bool: - return isinstance(column_type, sa.PickleType | sa.LargeBinary) + """Whether a SQL column stores pickled Python objects. + + Only ``PickleType`` columns are pickled. Plain ``LargeBinary`` columns + hold raw bytes (e.g. the blosc2-compressed Mask ``data`` leaf) and must + not be unpickled. After reflection both report as ``LargeBinary``, so + :meth:`_restore_pickled_column_types` re-tags the genuinely-pickled ones + as ``PickleType`` before this check is used. 
+ """ + return isinstance(column_type, sa.PickleType) + + def _pickled_physical_columns(self, table_class: type[DeclarativeBase]) -> list[str]: + """Physical column names whose values are pickled (vs stored natively).""" + return [col.name for col in table_class.__table__.columns if self._is_pickled_sql_type(col.type)] + + def _raw_binary_physical_columns(self, table_class: type[DeclarativeBase]) -> set[str]: + """Physical column names that hold raw ``pl.Binary`` bytes (never pickled). + + These are the only blob columns left as ``LargeBinary`` after reflection; + every other blob column is re-tagged as ``PickleType``. + """ + raw: set[str] = set() + for key, schema in self._attr_schemas_for_table(table_class).items(): + if isinstance(schema.dtype, pl.Struct): + raw.update( + flat_col + for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype) + if leaf_dtype == pl.Binary + ) + elif schema.dtype == pl.Binary: + raw.add(key) + return raw @property def __node_attr_schemas(self) -> dict[str, AttrSchema]: @@ -840,9 +875,20 @@ def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None: self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas self._edge_attr_schemas_cache = None - def _restore_pickled_column_types(self, table: sa.Table) -> None: - for column in table.columns: - if isinstance(column.type, sa.LargeBinary): + def _restore_pickled_column_types(self, table_class: type[DeclarativeBase]) -> None: + """Restore ``PickleType`` on reflected pickle columns. + + Reflection reports every blob column as ``LargeBinary``, losing the + distinction between genuinely-pickled columns and raw-binary ones. We + consult the stored schema (via :meth:`_pickled_physical_columns`) and + only re-tag the pickled ones, leaving raw-binary columns (e.g. the Mask + ``data`` leaf) as ``LargeBinary`` so writes store their bytes directly. 
+ """ + if table_class.__tablename__ not in (self.Node.__tablename__, self.Edge.__tablename__): + return + raw_binary = self._raw_binary_physical_columns(table_class) + for column in table_class.__table__.columns: + if isinstance(column.type, sa.LargeBinary) and column.name not in raw_binary: column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: @@ -1357,7 +1403,7 @@ def _get_neighbors( filter_node_ids, self.Node, ) - node_df = unpickle_bytes_columns(node_df) + node_df = unpickle_bytes_columns(node_df, self._pickled_physical_columns(self.Node)) node_df = self._cast_columns(self.Node, node_df) if single_node: @@ -1550,7 +1596,7 @@ def node_attrs( connection=session.connection(), schema_overrides=self._polars_schema_override(self.Node), ) - nodes_df = unpickle_bytes_columns(nodes_df) + nodes_df = unpickle_bytes_columns(nodes_df, self._pickled_physical_columns(self.Node)) nodes_df = self._cast_columns(self.Node, nodes_df) # Select using logical keys (struct columns are now reconstructed). 
@@ -1596,7 +1642,7 @@ def edge_attrs( connection=session.connection(), schema_overrides=self._polars_schema_override(self.Edge), ) - edges_df = unpickle_bytes_columns(edges_df) + edges_df = unpickle_bytes_columns(edges_df, self._pickled_physical_columns(self.Edge)) edges_df = self._cast_columns(self.Edge, edges_df) if unpack: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 17ec2f1f..65f54cac 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -17,6 +17,7 @@ from tracksdata.io._numpy_array import from_array from tracksdata.nodes import RegionPropsNodes from tracksdata.nodes._mask import Mask, as_mask +from tracksdata.utils._dtypes import STRUCT_FIELD_SEP def test_already_existing_keys(graph_backend: BaseGraph) -> None: @@ -1761,6 +1762,42 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: np.testing.assert_array_equal(stored_mask.mask, mask_data) +def test_sql_graph_mask_struct_stored_raw_not_pickled(tmp_path: Path) -> None: + """Mask struct's binary `data` leaf is stored as raw BLOB, not double-pickled. + + Regression for storing the (already blosc2-compressed) mask bytes through a + SQLAlchemy ``PickleType`` column, which wrapped them in a second pickle layer. + """ + db_path = tmp_path / "mask_struct.db" + graph = SQLGraph("sqlite", str(db_path)) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, Mask.struct_dtype(2)) + + mask_data = np.array([[True, False], [False, True]], dtype=bool) + mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) + node_id = graph.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask.to_struct()}) + + data_col = f"{DEFAULT_ATTR_KEYS.MASK}{STRUCT_FIELD_SEP}{Mask.MASK_DATA_FIELD}" + # The binary leaf must be a raw LargeBinary column, never a PickleType. 
+ assert isinstance(graph.Node.__table__.columns[data_col].type, sa.LargeBinary) + assert not isinstance(graph.Node.__table__.columns[data_col].type, sa.PickleType) + + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + # Still raw binary after reflection — not restored to PickleType. + assert isinstance(reloaded.Node.__table__.columns[data_col].type, sa.LargeBinary) + assert not isinstance(reloaded.Node.__table__.columns[data_col].type, sa.PickleType) + + df = reloaded.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK]) + assert df.schema[DEFAULT_ATTR_KEYS.MASK] == Mask.struct_dtype(2) + restored = as_mask(df[DEFAULT_ATTR_KEYS.MASK].to_list()[0]) + np.testing.assert_array_equal(restored.mask, mask_data) + np.testing.assert_array_equal(restored.bbox, mask.bbox) + + # Struct-field filtering still works against the flat physical bbox columns. + assert reloaded.filter(NodeAttr(DEFAULT_ATTR_KEYS.MASK).struct.field("min_y") == 0).node_ids() == [node_id] + + def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None: db_path = tmp_path / "struct_graph.db" graph = SQLGraph("sqlite", str(db_path)) diff --git a/src/tracksdata/utils/_dataframe.py b/src/tracksdata/utils/_dataframe.py index a6de0f17..7b28d6e1 100644 --- a/src/tracksdata/utils/_dataframe.py +++ b/src/tracksdata/utils/_dataframe.py @@ -1,3 +1,5 @@ +from collections.abc import Collection + import cloudpickle import polars as pl import polars.selectors as cs @@ -29,23 +31,35 @@ def unpack_array_attrs(df: pl.DataFrame) -> pl.DataFrame: return unpack_array_attrs(df) -def unpickle_bytes_columns(df: pl.DataFrame) -> pl.DataFrame: +def unpickle_bytes_columns(df: pl.DataFrame, columns: Collection[str]) -> pl.DataFrame: """ - Unpickle bytes columns from the database. + Unpickle pickled bytes columns read from the database. + + Only the columns in *columns* are unpickled. Raw-binary columns (e.g. 
the + blosc2-compressed ``data`` leaf of a Mask struct attribute) are stored + natively as ``pl.Binary`` and must be left untouched, so callers pass the + explicit set of genuinely-pickled physical columns rather than relying on + all binary columns being pickled. Parameters ---------- df : pl.DataFrame The DataFrame to unpickle the bytes columns from. + columns : Collection[str] + The physical column names that hold pickled values. Returns ------- pl.DataFrame - The DataFrame with the bytes columns unpickled. + The DataFrame with the pickled columns unpickled. """ - df = df.map_columns(cs.binary(), lambda x: x.map_elements(cloudpickle.loads, return_dtype=pl.Object)) - for col, dtype in zip(df.columns, df.dtypes, strict=True): - if isinstance(dtype, pl.Object): + targets = [col for col in columns if col in df.columns and df.schema[col] == pl.Binary] + if not targets: + return df + + df = df.map_columns(cs.by_name(targets), lambda x: x.map_elements(cloudpickle.loads, return_dtype=pl.Object)) + for col in targets: + if isinstance(df.schema[col], pl.Object): try: df = df.with_columns(pl.Series(df[col].to_list()).alias(col)) except Exception: diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 5c25b710..a245eee0 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -392,6 +392,9 @@ def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: # String types pl.String: sa.String, pl.Utf8: sa.String, + # Raw binary blobs are stored as-is (no pickle round-trip), e.g. the + # blosc2-compressed `data` leaf of a Mask struct attribute. 
+ pl.Binary: sa.LargeBinary, } diff --git a/src/tracksdata/utils/_test/test_dataframe.py b/src/tracksdata/utils/_test/test_dataframe.py index 0ac92389..8f2fe03e 100644 --- a/src/tracksdata/utils/_test/test_dataframe.py +++ b/src/tracksdata/utils/_test/test_dataframe.py @@ -45,7 +45,7 @@ def test_unpickle_bytes_columns_variable_size_arrays() -> None: arrays = [np.ones((41, 41), dtype=bool), np.ones((4, 4), dtype=bool)] df = pl.DataFrame({"mask": pl.Series([cloudpickle.dumps(a) for a in arrays], dtype=pl.Binary)}) - result = unpickle_bytes_columns(df) # must not raise SchemaError + result = unpickle_bytes_columns(df, ["mask"]) # must not raise SchemaError for actual, expected in zip(result["mask"].to_list(), arrays, strict=False): np.testing.assert_array_equal(actual, expected) From 5e156b9677eb718476861d5586142894ce94a926 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Tue, 16 Jun 2026 18:07:52 -0700 Subject: [PATCH 30/30] Add user-selectable mask compression codec Introduce `MaskCodec` (BLOSC2, RAW, PACKBITS) so callers can trade encode cost against stored size. The codec is recorded as the first byte of the packed payload, making each mask self-describing: a single column can mix codecs and any encoding stays readable regardless of the current default. - `set_default_mask_codec` / `get_default_mask_codec` expose the default codec used when none is passed; `Mask.to_struct(codec=...)` overrides per call. - RAW and PACKBITS prepend a tiny shape header (ndim + uint32 dims) so unpacking needs no external metadata; BLOSC2 keeps using its self-describing cframe. - RAW decodes as a zero-copy read-only view; `Mask.__isub__` now copies-on-write so in-place difference still works on such masks. Default remains BLOSC2, so existing behavior is unchanged. PACKBITS is typically far smaller for small/medium cell masks (e.g. 26 B vs 347 B for an 11x11 disk). 
Co-Authored-By: Claude Fable 5 --- src/tracksdata/nodes/__init__.py | 21 +++- src/tracksdata/nodes/_mask.py | 145 ++++++++++++++++++++++-- src/tracksdata/nodes/_test/test_mask.py | 69 ++++++++++- 3 files changed, 224 insertions(+), 11 deletions(-) diff --git a/src/tracksdata/nodes/__init__.py b/src/tracksdata/nodes/__init__.py index 1fb6f4df..3f7f8fe3 100644 --- a/src/tracksdata/nodes/__init__.py +++ b/src/tracksdata/nodes/__init__.py @@ -1,8 +1,25 @@ """Node operators for creating nodes and their respective attributes (e.g. masks) in a graph.""" from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs -from tracksdata.nodes._mask import Mask, MaskDiskAttrs, as_mask +from tracksdata.nodes._mask import ( + Mask, + MaskCodec, + MaskDiskAttrs, + as_mask, + get_default_mask_codec, + set_default_mask_codec, +) from tracksdata.nodes._random import RandomNodes from tracksdata.nodes._regionprops import RegionPropsNodes -__all__ = ["GenericFuncNodeAttrs", "Mask", "MaskDiskAttrs", "RandomNodes", "RegionPropsNodes", "as_mask"] +__all__ = [ + "GenericFuncNodeAttrs", + "Mask", + "MaskCodec", + "MaskDiskAttrs", + "RandomNodes", + "RegionPropsNodes", + "as_mask", + "get_default_mask_codec", + "set_default_mask_codec", +] diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 2abf0662..e9436202 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from enum import IntEnum from functools import cached_property, lru_cache from typing import TYPE_CHECKING, Any @@ -19,9 +20,80 @@ from tracksdata.graph._base_graph import BaseGraph -def _pack_mask_array(mask: NDArray) -> bytes: - """Compress a mask array into a blosc2 cframe.""" - mask = np.ascontiguousarray(mask) +class MaskCodec(IntEnum): + """Codec used to encode a mask's boolean array into bytes. 
+ + The chosen codec is stored as the first byte of the packed payload, so a + single column can mix codecs and any encoding stays readable regardless of + the current default. + + - ``BLOSC2``: blosc2 cframe. Best for large masks where the ~340 B fixed + header is amortized and RLE compression wins; slowest to encode (~300 µs). + - ``RAW``: the uncompressed boolean buffer. Smallest header; decodes as a + zero-copy view. Best when masks are tiny and incompressible. + - ``PACKBITS``: ``np.packbits`` (8 booleans per byte). Best all-rounder for + small/medium cell masks — far smaller than blosc2 without its overhead. + """ + + BLOSC2 = 0 + RAW = 1 + PACKBITS = 2 + + +# Default codec used when callers do not pass one explicitly. Mutated through +# `set_default_mask_codec` so users can trade encode cost against stored size. +_DEFAULT_MASK_CODEC = MaskCodec.BLOSC2 + + +def get_default_mask_codec() -> MaskCodec: + """Return the codec used to pack masks when none is given explicitly.""" + return _DEFAULT_MASK_CODEC + + +def set_default_mask_codec(codec: "MaskCodec | int | str") -> MaskCodec: + """Set the default codec used to pack masks. + + Parameters + ---------- + codec : MaskCodec | int | str + The codec, as a `MaskCodec`, its integer value, or its name + (e.g. ``"packbits"``, case-insensitive). + + Returns + ------- + MaskCodec + The resolved codec now in effect. + """ + global _DEFAULT_MASK_CODEC + _DEFAULT_MASK_CODEC = _resolve_codec(codec) + return _DEFAULT_MASK_CODEC + + +def _resolve_codec(codec: "MaskCodec | int | str | None") -> MaskCodec: + if codec is None: + return _DEFAULT_MASK_CODEC + if isinstance(codec, str): + try: + return MaskCodec[codec.upper()] + except KeyError: + raise ValueError(f"Unknown mask codec '{codec}'. 
+ Options: {[c.name.lower() for c in MaskCodec]}") from None + return MaskCodec(codec) + + +def _encode_shape_header(mask: NDArray) -> bytes: + """Encode ``ndim`` and the shape so non-self-describing codecs can reshape.""" + return bytes([mask.ndim]) + np.asarray(mask.shape, dtype="<u4").tobytes() + + +def _decode_shape_header(data: bytes, offset: int) -> tuple[tuple[int, ...], int]: + ndim = data[offset] + offset += 1 + shape = tuple(int(s) for s in np.frombuffer(data, dtype="<u4", count=ndim, offset=offset)) + return shape, offset + 4 * ndim + + +def _blosc2_pack(mask: NDArray) -> bytes: prev_nthreads = blosc2.set_nthreads(1) # Bypass blosc2 printing overhead by directly creating a schunk and converting it to cframe, # instead of using blosc2.pack_tensor @@ -33,14 +105,61 @@ def _pack_mask_array(mask: NDArray) -> bytes: return cframe -def _unpack_mask_array(data: bytes) -> NDArray: - """Decompress a blosc2 cframe into a mask array.""" +def _blosc2_unpack(data: bytes) -> NDArray: prev_nthreads = blosc2.set_nthreads(1) mask = blosc2.unpack_tensor(data) blosc2.set_nthreads(prev_nthreads) return mask +def _pack_mask_array(mask: NDArray, codec: "MaskCodec | int | str | None" = None) -> bytes: + """Encode a mask array into bytes, tagged with the codec used. + + Parameters + ---------- + mask : NDArray + The boolean mask array. + codec : MaskCodec | int | str | None + The codec to use. If None, [get_default_mask_codec][tracksdata.nodes.get_default_mask_codec] is used. + + Returns + ------- + bytes + ``bytes([codec]) + payload``; self-describing so `_unpack_mask_array` + needs no external metadata. 
+ """ + codec = _resolve_codec(codec) + mask = np.ascontiguousarray(mask, dtype=bool) + + if codec == MaskCodec.BLOSC2: + return bytes([codec]) + _blosc2_pack(mask) + + header = _encode_shape_header(mask) + if codec == MaskCodec.RAW: + return bytes([codec]) + header + mask.tobytes() + if codec == MaskCodec.PACKBITS: + return bytes([codec]) + header + np.packbits(mask.reshape(-1)).tobytes() + raise ValueError(f"Unsupported mask codec: {codec}") + + +def _unpack_mask_array(data: bytes) -> NDArray: + """Decode bytes produced by `_pack_mask_array` back into a mask array.""" + codec = MaskCodec(data[0]) + + if codec == MaskCodec.BLOSC2: + return _blosc2_unpack(data[1:]) + + shape, offset = _decode_shape_header(data, 1) + payload = data[offset:] + if codec == MaskCodec.RAW: + return np.frombuffer(payload, dtype=bool).reshape(shape) + if codec == MaskCodec.PACKBITS: + count = int(np.prod(shape)) if shape else 1 + flat = np.unpackbits(np.frombuffer(payload, dtype=np.uint8), count=count) + return flat.astype(bool).reshape(shape) + raise ValueError(f"Unsupported mask codec: {codec}") + + @lru_cache(maxsize=5) def _nd_sphere( radius: int, @@ -312,6 +431,10 @@ def __isub__(self, other: "Mask") -> "Mask": if self.intersection(other) == 0: return self + # `_mask` may be a read-only zero-copy view (RAW codec); copy before mutating in place. + if not self._mask.flags.writeable: + self._mask = self._mask.copy() + other_slicing = [] self_slicing = [] for i in range(self._mask.ndim): @@ -560,18 +683,24 @@ def struct_dtype(ndim: int) -> pl.Struct: fields[Mask.MASK_DATA_FIELD] = pl.Binary return pl.Struct(fields) - def to_struct(self) -> dict[str, Any]: + def to_struct(self, codec: "MaskCodec | int | str | None" = None) -> dict[str, Any]: """ Convert the mask to a dict matching [struct_dtype][tracksdata.nodes.Mask.struct_dtype]. + Parameters + ---------- + codec : MaskCodec | int | str | None + The codec used to encode the mask array. 
If None, the default set by + [set_default_mask_codec][tracksdata.nodes.set_default_mask_codec] is used. + Returns ------- dict[str, Any] - Scalar bounding box fields plus the blosc2-compressed mask under ``"data"``. + Scalar bounding box fields plus the encoded mask under ``"data"``. """ fields = self.bbox_struct_fields(self._mask.ndim) value: dict[str, Any] = {f: int(b) for f, b in zip(fields, self._bbox, strict=True)} - value[self.MASK_DATA_FIELD] = _pack_mask_array(self._mask) + value[self.MASK_DATA_FIELD] = _pack_mask_array(self._mask, codec) return value @classmethod diff --git a/src/tracksdata/nodes/_test/test_mask.py b/src/tracksdata/nodes/_test/test_mask.py index dc017e25..4e7a2ebe 100644 --- a/src/tracksdata/nodes/_test/test_mask.py +++ b/src/tracksdata/nodes/_test/test_mask.py @@ -1,7 +1,16 @@ import numpy as np import pytest -from tracksdata.nodes._mask import Mask, _nd_sphere, as_mask +from tracksdata.nodes._mask import ( + Mask, + MaskCodec, + _nd_sphere, + _pack_mask_array, + _unpack_mask_array, + as_mask, + get_default_mask_codec, + set_default_mask_codec, +) def test_mask_init() -> None: @@ -685,3 +694,61 @@ def test_mask_struct_attr_in_graph(graph_backend) -> None: # filtering on a bbox field of the mask struct filtered = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.MASK).struct.field("min_y") > 2).node_ids() assert filtered == [node_b] + + +@pytest.mark.parametrize("codec", list(MaskCodec)) +@pytest.mark.parametrize("ndim", [1, 2, 3]) +def test_pack_unpack_codec_roundtrip(codec: MaskCodec, ndim: int) -> None: + rng = np.random.default_rng(0) + shape = (7, 5, 4)[:ndim] + mask = rng.uniform(size=shape) > 0.5 + + packed = _pack_mask_array(mask, codec) + assert packed[0] == codec # codec is tagged in the first byte + restored = _unpack_mask_array(packed) + + assert restored.shape == mask.shape + np.testing.assert_array_equal(restored, mask) + + +def test_codec_resolution_by_name_and_int() -> None: + mask = np.ones((3, 3), dtype=bool) + assert 
_unpack_mask_array(_pack_mask_array(mask, "packbits"))[0, 0] + assert _pack_mask_array(mask, "raw")[0] == MaskCodec.RAW + assert _pack_mask_array(mask, int(MaskCodec.BLOSC2))[0] == MaskCodec.BLOSC2 + + with pytest.raises(ValueError): + _pack_mask_array(mask, "nope") + + +def test_default_mask_codec_switch() -> None: + original = get_default_mask_codec() + try: + set_default_mask_codec("packbits") + assert get_default_mask_codec() == MaskCodec.PACKBITS + # to_struct() with no explicit codec uses the default + mask = Mask(np.ones((4, 4), dtype=bool), bbox=np.array([0, 0, 4, 4])) + assert mask.to_struct()["data"][0] == MaskCodec.PACKBITS + # explicit codec overrides the default + assert mask.to_struct(codec="raw")["data"][0] == MaskCodec.RAW + finally: + set_default_mask_codec(original) + + +def test_mixed_codecs_in_one_column_are_readable() -> None: + """A column may hold masks packed with different codecs; each stays readable.""" + mask = np.array([[True, False], [True, True]], dtype=bool) + bbox = np.array([0, 0, 2, 2]) + for codec in MaskCodec: + value = Mask(mask, bbox=bbox).to_struct(codec=codec) + assert as_mask(value) == Mask(mask, bbox=bbox) + + +def test_raw_codec_mask_is_mutable_after_difference() -> None: + """RAW decodes to a read-only view; in-place difference must still work.""" + a = Mask(np.ones((4, 4), dtype=bool), bbox=np.array([0, 0, 4, 4])) + decoded = as_mask(a.to_struct(codec="raw")) + other = Mask(np.ones((2, 2), dtype=bool), bbox=np.array([0, 0, 2, 2])) + decoded -= other # would raise on a read-only array without the copy-on-write guard + assert not decoded.mask[0, 0] + assert decoded.mask[3, 3]