Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
29177f3
added private metadata machinery
yfukai Feb 17, 2026
d8292f1
before adding private
yfukai Feb 17, 2026
cff5898
added private metadata view
yfukai Feb 17, 2026
68b01d4
renamed func
yfukai Feb 17, 2026
1ae2426
implementation of saving and loading dtypes as metadata
yfukai Feb 17, 2026
c50a07b
lint
yfukai Feb 17, 2026
e9bf28f
restricted dtype metadata to sqlgraph
yfukai Feb 18, 2026
9aa9c3a
udpated serialization strategies
yfukai Feb 18, 2026
7e61ac3
solved failing tests
yfukai Feb 18, 2026
e5968bf
added test for shape-less pl.Array (xfail)
yfukai Feb 18, 2026
b4acde3
working
yfukai Feb 19, 2026
cc55976
simplified code
yfukai Feb 19, 2026
e76d8e5
initial try
yfukai Feb 19, 2026
7bec369
saving private metadata
yfukai Feb 20, 2026
852f717
rustworkx reviewed
yfukai Feb 26, 2026
4c151bb
Merge branch 'from_other_roundtrip' into struct_attr
yfukai Feb 26, 2026
4af9904
working with clean code?
yfukai Feb 26, 2026
19055ab
Merge branch 'main' into struct_attr
JoOkuma Feb 27, 2026
6c69e76
updated impl
yfukai Apr 10, 2026
d9bee26
removed codex config wrongly added
yfukai Apr 10, 2026
cc0beb4
issue fixes
yfukai Apr 14, 2026
007d4c7
Merge branch 'main' into struct_attr
yfukai Apr 14, 2026
ffea2ec
rolled back unncessary change
yfukai Apr 14, 2026
0ad6c60
Merge remote-tracking branch 'upstream/main' into struct_attr
yfukai May 28, 2026
5e7331f
additional comments
yfukai May 28, 2026
7e07801
Merge branch 'main' into struct_attr
JoOkuma Jun 1, 2026
37f8fc9
Fix lint: remove whitespace from blank lines
JoOkuma Jun 1, 2026
1842c55
fixes
yfukai Jun 4, 2026
db35287
refactor aligning main
yfukai Jun 4, 2026
3759d9c
Restore scratch-table machinery and tests from main
yfukai Jun 5, 2026
f5b7cc0
Merge branch 'main' of https://github.com/royerlab/tracksdata into st…
yfukai Jun 8, 2026
0d76262
ignored the devcontaienr
yfukai Jun 8, 2026
87707d0
bugfix
yfukai Jun 8, 2026
9ba99e1
Store Mask as a struct attribute instead of pickled pl.Object
yfukai Jun 10, 2026
e7f99af
Merge upstream/main into mask_struct_attr
yfukai Jun 17, 2026
7d02b80
Store binary attributes as raw BLOB in SQL instead of pickling
yfukai Jun 17, 2026
00f41eb
Clean up mask struct-attribute branch for review
yfukai Jun 17, 2026
07b68c6
Document why as_mask imports are function-local
yfukai Jun 17, 2026
1a6ad08
updating name and adding comments
yfukai Jun 22, 2026
a21dd31
adding test
yfukai Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from copy import copy
from typing import TYPE_CHECKING, Any
from typing import Any

import numpy as np

Expand All @@ -11,9 +11,6 @@
from tracksdata.options import get_options
from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype

if TYPE_CHECKING:
from tracksdata.nodes._mask import Mask


def _validate_shape(
shape: tuple[int, ...] | None,
Expand Down Expand Up @@ -346,14 +343,17 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda
np.ndarray
The filled buffer.
"""
# Local import: avoids the graph <-> nodes package import cycle (importing
# tracksdata.nodes re-enters the partially-initialized graph package).
from tracksdata.nodes._mask import as_mask

subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)]
df = subgraph.node_attrs(
attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK],
)

for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True):
mask: Mask
mask.paint_buffer(buffer, value, offset=self._offset)
as_mask(mask).paint_buffer(buffer, value, offset=self._offset)

def _offset_as_array(self, ndim: int) -> np.ndarray:
"""Normalize `offset` to a vector for each spatial axis."""
Expand Down
9 changes: 7 additions & 2 deletions src/tracksdata/edges/_iou_edges.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs
from tracksdata.nodes._mask import Mask
from tracksdata.nodes._mask import Mask, as_mask


def _mask_iou(source_mask: "Mask | dict", target_mask: "Mask | dict") -> float:
"""IoU between two mask attribute values (struct dicts or `Mask` instances)."""
return as_mask(source_mask).iou(as_mask(target_mask))


class IoUEdgeAttr(GenericFuncEdgeAttrs):
Expand All @@ -22,7 +27,7 @@ def __init__(
mask_key: str = DEFAULT_ATTR_KEYS.MASK,
):
super().__init__(
func=Mask.iou,
func=_mask_iou,
attr_keys=mask_key,
output_key=output_key,
)
5 changes: 3 additions & 2 deletions src/tracksdata/edges/_test/test_iou_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.edges import IoUEdgeAttr
from tracksdata.edges._iou_edges import _mask_iou
from tracksdata.graph import RustWorkXGraph
from tracksdata.nodes import Mask
from tracksdata.options import get_options, options_context
Expand All @@ -15,7 +16,7 @@ def test_iou_edges_init_default() -> None:

assert operator.output_key == "iou_score"
assert operator.attr_keys == DEFAULT_ATTR_KEYS.MASK
assert operator.func == Mask.iou
assert operator.func == _mask_iou


def test_iou_edges_init_custom() -> None:
Expand All @@ -24,7 +25,7 @@ def test_iou_edges_init_custom() -> None:

assert operator.output_key == "custom_iou"
assert operator.attr_keys == "custom_mask"
assert operator.func == Mask.iou
assert operator.func == _mask_iou


@pytest.mark.parametrize("n_workers", [1, 2])
Expand Down
4 changes: 2 additions & 2 deletions src/tracksdata/functional/_test/test_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.functional import to_napari_format
from tracksdata.graph import RustWorkXGraph
from tracksdata.nodes import MaskDiskAttrs
from tracksdata.nodes import MaskDiskAttrs, as_mask


@pytest.mark.parametrize("metadata_shape", [True, False])
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_napari_conversion(metadata_shape: bool) -> None:
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 6))
masks = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK]
graph.update_node_attrs(
attrs={DEFAULT_ATTR_KEYS.BBOX: [mask.bbox for mask in masks]},
attrs={DEFAULT_ATTR_KEYS.BBOX: [as_mask(mask).bbox for mask in masks]},
node_ids=graph.node_ids(),
)

Expand Down
10 changes: 8 additions & 2 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,11 +1378,14 @@ def compute_overlaps(self, iou_threshold: float = 0.0) -> None:
raise ValueError("iou_threshold must be between 0.0 and 1.0")

def _estimate_overlaps(t: int) -> list[list[int, 2]]:
# Local import: avoids the graph <-> nodes package import cycle.
from tracksdata.nodes._mask import as_mask

node_attrs = self.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t).node_attrs(
attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MASK],
)
node_ids = node_attrs[DEFAULT_ATTR_KEYS.NODE_ID].to_list()
masks = node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list()
masks = [as_mask(m) for m in node_attrs[DEFAULT_ATTR_KEYS.MASK].to_list()]
overlaps = []
for i in range(len(masks)):
mask_i = masks[i]
Expand Down Expand Up @@ -1939,8 +1942,11 @@ def to_geff(
}

if DEFAULT_ATTR_KEYS.MASK in node_attrs.columns:
# Local import: avoids the graph <-> nodes package import cycle.
from tracksdata.nodes._mask import as_mask

node_dict[DEFAULT_ATTR_KEYS.MASK] = construct_var_len_props(
[mask.mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]]
[as_mask(mask).mask.astype(np.uint64) for mask in node_attrs[DEFAULT_ATTR_KEYS.MASK]]
)

edge_dict = {k: {"values": column_to_numpy(v), "missing": None} for k, v in edge_attrs.to_dict().items()}
Expand Down
73 changes: 60 additions & 13 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.graph.filters._base_filter import BaseFilter
from tracksdata.utils._cache import cache_method
from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns
from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_columns
from tracksdata.utils._dtypes import (
STRUCT_FIELD_SEP,
AttrSchema,
Expand Down Expand Up @@ -422,7 +422,7 @@ def _read_attr_dataframe(self, query: sa.Select, table: type[DeclarativeBase]) -
schema_overrides=self._graph._polars_schema_override(table),
)

df = unpickle_bytes_columns(df)
df = unpickle_columns(df, self._graph._pickled_physical_columns(table))
return self._graph._cast_columns(table, df)

def _query_from_attr_keys(
Expand Down Expand Up @@ -656,6 +656,11 @@ def __init__(
self._engine_kwargs = engine_kwargs if engine_kwargs is not None else {}
self._engine = sa.create_engine(self._url, **self._engine_kwargs)

# Initialized before `_define_schema`, which (when reloading an existing
# database) reads the attribute schemas to restore pickle column types.
self._node_attr_schemas_cache: dict | None = None
self._edge_attr_schemas_cache: dict | None = None

# Create unique classes for this instance
self._define_schema(overwrite=overwrite)

Expand All @@ -666,8 +671,6 @@ def __init__(

self._max_id_per_time = {}
self._update_max_id_per_time()
self._node_attr_schemas_cache: dict | None = None
self._edge_attr_schemas_cache: dict | None = None

def supports_custom_indices(self) -> bool:
return True
Expand All @@ -686,8 +689,6 @@ class Base(DeclarativeBase):
pass

if len(metadata.tables) > 0 and not overwrite:
for table in metadata.tables.values():
self._restore_pickled_column_types(table)
for table_name, table in metadata.tables.items():
cls = type(
table_name,
Expand All @@ -699,6 +700,10 @@ class Base(DeclarativeBase):
)
setattr(self, table_name, cls)
self.Base = Base
# Restore pickle column types only after the ORM classes exist, so
# `_pickled_physical_columns` can read the stored attribute schemas.
for table_name in metadata.tables:
self._restore_pickled_column_types(getattr(self, table_name))
return

class Node(Base):
Expand Down Expand Up @@ -794,7 +799,37 @@ def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[st

@staticmethod
def _is_pickled_sql_type(column_type: TypeEngine) -> bool:
return isinstance(column_type, sa.PickleType | sa.LargeBinary)
"""Whether a SQL column stores pickled Python objects.

Only ``PickleType`` columns are pickled. Plain ``LargeBinary`` columns
hold raw bytes (e.g. the blosc2-compressed Mask ``data`` leaf) and must
not be unpickled. After reflection both report as ``LargeBinary``, so
:meth:`_restore_pickled_column_types` re-tags the genuinely-pickled ones
as ``PickleType`` before this check is used.
"""
return isinstance(column_type, sa.PickleType)

def _pickled_physical_columns(self, table_class: type[DeclarativeBase]) -> list[str]:
"""Physical column names whose values are pickled (vs stored natively)."""
return [col.name for col in table_class.__table__.columns if self._is_pickled_sql_type(col.type)]

def _raw_binary_physical_columns(self, table_class: type[DeclarativeBase]) -> set[str]:
"""Physical column names that hold raw ``pl.Binary`` bytes (never pickled).

These are the only blob columns left as ``LargeBinary`` after reflection;
every other blob column is re-tagged as ``PickleType``.
"""
raw: set[str] = set()
for key, schema in self._attr_schemas_for_table(table_class).items():
if isinstance(schema.dtype, pl.Struct):
raw.update(
flat_col
for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype)
if leaf_dtype == pl.Binary
)
elif schema.dtype == pl.Binary:
raw.add(key)
return raw

@property
def __node_attr_schemas(self) -> dict[str, AttrSchema]:
Expand Down Expand Up @@ -840,9 +875,21 @@ def __edge_attr_schemas(self, schemas: dict[str, AttrSchema]) -> None:
self._private_metadata[self._PRIVATE_SQL_EDGE_SCHEMA_STORE_KEY] = encoded_schemas
self._edge_attr_schemas_cache = None

def _restore_pickled_column_types(self, table: sa.Table) -> None:
for column in table.columns:
if isinstance(column.type, sa.LargeBinary):
def _restore_pickled_column_types(self, table_class: type[DeclarativeBase]) -> None:
"""Restore ``PickleType`` on reflected pickle columns.

Reflection reports every blob column as ``LargeBinary``, losing the
distinction between genuinely-pickled columns and raw-binary ones. We
consult the stored schema (via :meth:`_raw_binary_physical_columns`) and
re-tag every blob column as ``PickleType`` except the raw-binary ones
(e.g. the Mask ``data`` leaf), which stay ``LargeBinary`` so writes store
their bytes directly.
"""
if table_class.__tablename__ not in (self.Node.__tablename__, self.Edge.__tablename__):
return
raw_binary = self._raw_binary_physical_columns(table_class)
for column in table_class.__table__.columns:
if isinstance(column.type, sa.LargeBinary) and column.name not in raw_binary:
column.type = sa.PickleType()

def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict:
Expand Down Expand Up @@ -1357,7 +1404,7 @@ def _get_neighbors(
filter_node_ids,
self.Node,
)
node_df = unpickle_bytes_columns(node_df)
node_df = unpickle_columns(node_df, self._pickled_physical_columns(self.Node))
node_df = self._cast_columns(self.Node, node_df)

if single_node:
Expand Down Expand Up @@ -1550,7 +1597,7 @@ def node_attrs(
connection=session.connection(),
schema_overrides=self._polars_schema_override(self.Node),
)
nodes_df = unpickle_bytes_columns(nodes_df)
nodes_df = unpickle_columns(nodes_df, self._pickled_physical_columns(self.Node))
nodes_df = self._cast_columns(self.Node, nodes_df)

# Select using logical keys (struct columns are now reconstructed).
Expand Down Expand Up @@ -1596,7 +1643,7 @@ def edge_attrs(
connection=session.connection(),
schema_overrides=self._polars_schema_override(self.Edge),
)
edges_df = unpickle_bytes_columns(edges_df)
edges_df = unpickle_columns(edges_df, self._pickled_physical_columns(self.Edge))
edges_df = self._cast_columns(self.Edge, edges_df)

if unpack:
Expand Down
45 changes: 40 additions & 5 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from tracksdata.graph import BaseGraph, IndexedRXGraph, RustWorkXGraph, SQLGraph
from tracksdata.io._numpy_array import from_array
from tracksdata.nodes import RegionPropsNodes
from tracksdata.nodes._mask import Mask
from tracksdata.nodes._mask import Mask, as_mask
from tracksdata.utils._dtypes import STRUCT_FIELD_SEP


def test_already_existing_keys(graph_backend: BaseGraph) -> None:
Expand Down Expand Up @@ -1653,10 +1654,8 @@ def build_node_map(graph: BaseGraph) -> dict[tuple[int, tuple[int, ...]], dict[s
assert target_row["y"] == pytest.approx(source_row["y"])
assert target_row["x"] == pytest.approx(source_row["x"])

source_mask = source_row[DEFAULT_ATTR_KEYS.MASK]
target_mask = target_row[DEFAULT_ATTR_KEYS.MASK]
assert isinstance(source_mask, Mask)
assert isinstance(target_mask, Mask)
source_mask = as_mask(source_row[DEFAULT_ATTR_KEYS.MASK])
target_mask = as_mask(target_row[DEFAULT_ATTR_KEYS.MASK])
np.testing.assert_array_equal(source_mask.mask, target_mask.mask)
np.testing.assert_array_equal(source_mask.bbox, target_mask.bbox)

Expand Down Expand Up @@ -1763,6 +1762,42 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None:
np.testing.assert_array_equal(stored_mask.mask, mask_data)


def test_sql_graph_mask_struct_stored_raw_not_pickled(tmp_path: Path) -> None:
"""Mask struct's binary `data` leaf is stored as raw BLOB, not double-pickled.

Regression for storing the (already blosc2-compressed) mask bytes through a
SQLAlchemy ``PickleType`` column, which wrapped them in a second pickle layer.
"""
db_path = tmp_path / "mask_struct.db"
graph = SQLGraph("sqlite", str(db_path))
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, Mask.struct_dtype(2))

mask_data = np.array([[True, False], [False, True]], dtype=bool)
mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2]))
node_id = graph.add_node({"t": 0, DEFAULT_ATTR_KEYS.MASK: mask.to_struct()})

data_col = f"{DEFAULT_ATTR_KEYS.MASK}{STRUCT_FIELD_SEP}{Mask.MASK_DATA_FIELD}"
# The binary leaf must be a raw LargeBinary column, never a PickleType.
assert isinstance(graph.Node.__table__.columns[data_col].type, sa.LargeBinary)
assert not isinstance(graph.Node.__table__.columns[data_col].type, sa.PickleType)

graph._engine.dispose()

reloaded = SQLGraph("sqlite", str(db_path))
# Still raw binary after reflection — not restored to PickleType.
assert isinstance(reloaded.Node.__table__.columns[data_col].type, sa.LargeBinary)
assert not isinstance(reloaded.Node.__table__.columns[data_col].type, sa.PickleType)

df = reloaded.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])
assert df.schema[DEFAULT_ATTR_KEYS.MASK] == Mask.struct_dtype(2)
restored = as_mask(df[DEFAULT_ATTR_KEYS.MASK].to_list()[0])
np.testing.assert_array_equal(restored.mask, mask_data)
np.testing.assert_array_equal(restored.bbox, mask.bbox)

# Struct-field filtering still works against the flat physical bbox columns.
assert reloaded.filter(NodeAttr(DEFAULT_ATTR_KEYS.MASK).struct.field("min_y") == 0).node_ids() == [node_id]


def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None:
db_path = tmp_path / "struct_graph.db"
graph = SQLGraph("sqlite", str(db_path))
Expand Down
4 changes: 2 additions & 2 deletions src/tracksdata/metrics/_ctc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def _matching_data(
]:
attr_keys = [DEFAULT_ATTR_KEYS.T, tracklet_id_key, *required_attrs]
nodes_df = graph.node_attrs(attr_keys=attr_keys)
if use_mask_serialization:
# required by multiprocessing
if use_mask_serialization and nodes_df[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Object:
# required by multiprocessing; struct mask columns are Arrow-native and need no pickling
nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK)
labels = {}

Expand Down
14 changes: 8 additions & 6 deletions src/tracksdata/metrics/_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,27 @@ def compute_weights(
tuple[list[int], list[int], list[int], list[int], list[float]]
Matching data: mapped_ref, mapped_comp, rows, cols, weights (IoU values).
"""
# Local import: avoids the graph <-> nodes package import cycle.
from tracksdata.nodes._mask import as_mask
from tracksdata.utils._dtypes import column_from_bytes

# Handle serialized masks if needed
if ref_group[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Binary:
ref_group = column_from_bytes(ref_group, DEFAULT_ATTR_KEYS.MASK)
comp_group = column_from_bytes(comp_group, DEFAULT_ATTR_KEYS.MASK)

# Materialize masks once, struct values decompress on conversion
ref_masks = [as_mask(m) for m in ref_group[DEFAULT_ATTR_KEYS.MASK]]
comp_masks = [as_mask(m) for m in comp_group[DEFAULT_ATTR_KEYS.MASK]]

mapped_ref = []
mapped_comp = []
rows = []
cols = []
weights = []

for i, (ref_id, ref_mask) in enumerate(
zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
):
for j, (comp_id, comp_mask) in enumerate(
zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
):
for i, (ref_id, ref_mask) in enumerate(zip(ref_group[reference_graph_key], ref_masks, strict=True)):
for j, (comp_id, comp_mask) in enumerate(zip(comp_group[input_graph_key], comp_masks, strict=True)):
# Intersection over reference is used to select the matches
inter = ref_mask.intersection(comp_mask)
ctc_score = inter / ref_mask.size
Expand Down
Loading
Loading