diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index c2302f2d..6cab7d20 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -116,6 +116,23 @@ def test_attr_expr_method_delegation() -> None: assert result.to_list() == expected.to_list() +def test_attr_expr_struct_field_method_delegation() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 3}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + expr = NodeAttr("s").struct.field("x") + result = expr.evaluate(df) + assert isinstance(expr, NodeAttr) + assert result.to_list() == [1, 2, 3] + + +def test_attr_comparison_struct_field() -> None: + df = pl.DataFrame({"s": [{"x": 1}, {"x": 2}, {"x": 1}]}, schema={"s": pl.Struct({"x": pl.Int64})}) + comp = NodeAttr("s").struct.field("x") == 1 + result = comp.to_attr().evaluate(df) + assert comp.column == "s" + assert comp.attr.field_path == ("x",) + assert result.to_list() == [True, False, True] + + def test_attr_expr_complex_expression() -> None: df = pl.DataFrame({"iou": [0.5, 0.7, 0.9], "distance": [10, 20, 30]}) expr = (1 - Attr("iou")) * Attr("distance") diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 589ef52e..c97755ed 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -201,7 +201,10 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr raise ValueError(f"Comparison operators are not supported for multiple columns. Found {columns}.") self.attr = attr - self.column = columns[0] + # Prefer the explicitly tracked root_column so struct-field comparisons + # (e.g. `NodeAttr("m").struct.field("x") == 1`) record the parent storage + # column ("m"), letting backends remap to their physical layout via field_path. + self.column = attr.root_column if attr.root_column is not None else columns[0] self.op = op # casting numpy scalars to python scalars @@ -216,14 +219,18 @@ def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExpr self.other = other def __repr__(self) -> str: - return f"{type(self.attr).__name__}({self.column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" + if self.attr.field_path: + column = ".".join([str(self.column), *self.attr.field_path]) + else: + column = str(self.column) + return f"{type(self.attr).__name__}({column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" def to_attr(self) -> "Attr": """ Transform the comparison back to an [Attr][tracksdata.attrs.Attr] object. This is useful for evaluating the expression on a DataFrame. """ - return Attr(self.op(pl.col(self.column), self.other)) + return Attr(self.op(self.attr.expr, self.other)) @property def columns(self) -> list[str]: @@ -268,6 +275,39 @@ def __gt__(self, other: ExprInput) -> "Attr": ... def __ge__(self, other: ExprInput) -> "Attr": ... +class _StructNamespace: + """Wrapper around polars struct namespace that preserves Attr semantics. + + Polars' own ``Expr.struct.field(name)`` only updates the underlying expression; + it loses the parent column identity, which backends need to map a filter back + to its physical storage (e.g. SQL flat columns, dict lookups in rustworkx). + This wrapper proxies the namespace while threading ``root_column`` and + ``field_path`` through ``.field(...)`` calls. + """ + + def __init__(self, attr: "Attr") -> None: + self._attr = attr + self._namespace = attr.expr.struct + + def field(self, name: str) -> "Attr": + # preserve_field_path keeps the existing root/path before appending the new field. + out = self._attr._wrap(self._namespace.field(name), preserve_field_path=True) + # _namespace.field() always returns a polars Expr, so _wrap always yields an Attr here. + out._append_field_path(name) + return out + + def __getattr__(self, name: str) -> Any: + namespace_attr = getattr(self._namespace, name) + if callable(namespace_attr): + + @functools.wraps(namespace_attr) + def _wrapped(*args, **kwargs): + return self._attr._wrap(namespace_attr(*args, **kwargs)) + + return _wrapped + return namespace_attr + + class Attr: """ A class to compose an attribute expression for attribute filtering or value evaluation. @@ -292,30 +332,43 @@ class Attr: def __init__(self, value: ExprInput) -> None: self._inf_exprs = [] # expressions multiplied by +inf self._neg_inf_exprs = [] # expressions multiplied by -inf + # Path-tracking for backend filters: + # - root_column: top-level column used to store the value. + # - field_path: nested struct path from that root column. + self._root_column: str | None = None + self._field_path: tuple[str, ...] = () if isinstance(value, str): self.expr = pl.col(value) + self._root_column = value elif isinstance(value, Attr): self.expr = value.expr # Copy infinity tracking from the other AttrExpr self._inf_exprs = value.inf_exprs self._neg_inf_exprs = value.neg_inf_exprs + self._root_column = value.root_column + self._field_path = value.field_path elif isinstance(value, AttrComparison): attr = value.to_attr() self.expr = attr.expr self._inf_exprs = attr.inf_exprs self._neg_inf_exprs = attr.neg_inf_exprs + self._root_column = attr.root_column + self._field_path = attr.field_path elif isinstance(value, Expr): self.expr = value else: self.expr = pl.lit(value) - def _wrap(self, expr: ExprInput) -> Union["Attr", Any]: + def _wrap(self, expr: ExprInput, *, preserve_field_path: bool = False) -> Union["Attr", Any]: if isinstance(expr, Expr): - result = Attr(expr) + result = type(self)(expr) # Propagate infinity tracking result._inf_exprs = self._inf_exprs.copy() result._neg_inf_exprs = self._neg_inf_exprs.copy() + if preserve_field_path: + result._root_column = self._root_column + result._field_path = self._field_path return result return expr @@ -437,6 +490,33 @@ def evaluate(self, df: DataFrame) -> Series: def columns(self) -> list[str]: return list(dict.fromkeys(self.expr_columns + self.inf_columns + self.neg_inf_columns)) + @property + def root_column(self) -> str | None: + """ + Top-level column name from which this expression originates. + + Examples + -------- + `Attr("t").root_column == "t"` + `NodeAttr("measurements").struct.field("score").root_column == "measurements"` + """ + return self._root_column + + @property + def field_path(self) -> tuple[str, ...]: + """ + Nested struct-field path relative to [root_column][tracksdata.attrs.Attr.root_column]. + + Empty tuple means no nested access. + + Examples + -------- + `Attr("t").field_path == ()` + `NodeAttr("measurements").struct.field("score").field_path == ("score",)` + `NodeAttr("meta").struct.field("det").struct.field("conf").field_path == ("det", "conf")` + """ + return self._field_path + @property def inf_exprs(self) -> list["Attr"]: """Get the expressions multiplied by positive infinity.""" @@ -524,6 +604,9 @@ def __getattr__(self, attr: str) -> Any: if attr.startswith("_"): raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + if attr == "struct": + return _StructNamespace(self) + # To auto generate operator methods such as `.log()`` expr_attr = getattr(self.expr, attr) if callable(expr_attr): @@ -535,6 +618,12 @@ def _wrapped(*args, **kwargs): return _wrapped return expr_attr + def _append_field_path(self, field_name: str) -> None: + if self._root_column is None: + self._field_path = () + else: + self._field_path = (*self._field_path, field_name) + def __repr__(self) -> str: return f"Attr({self.expr})" @@ -880,4 +969,9 @@ def polars_reduce_attr_comps( """ if not attr_comps: raise ValueError("No attribute comparisons provided.") + # `f.to_attr().expr` lets each filter render its own expression. For + # `AttrComparison` over a struct field, that expression already drills into + # the struct (e.g. `pl.col("m").struct.field("x")`) rather than reading the + # bare column. For compound `AttrFilter`s, it returns the combined boolean + # expression. return pl.reduce(reduce_op, [f.to_attr().expr for f in attr_comps]) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ebef1885..254935e8 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -93,6 +93,34 @@ def _list_to_pl_series(key: str, values: list[Any], schema: AttrSchema) -> pl.Se return s +def _extract_field_path(value: Any, field_path: tuple[str, ...]) -> Any: + """Walk a struct field path through a Python attribute value. + + Rustworkx stores attributes as plain Python objects (typically dicts for + struct attrs) rather than polars columns, so struct-field filters can't be + pushed down into an expression — we walk the path manually here. We also + accept sequence- and attribute-style access to keep this robust for users + who pass nested dataclasses or tuples through the attr dict. + """ + for field in field_path: + if value is None: + return None + + if isinstance(value, dict): + value = value.get(field, None) + continue + + try: + value = value[field] + except (KeyError, IndexError, TypeError): + try: + value = getattr(value, field) + except AttributeError: + return None + + return value + + def _eval_filter( f: Filter, attrs: dict[str, Any], @@ -101,6 +129,8 @@ def _eval_filter( """Evaluate a single comparison or compound filter against an attrs dict.""" if isinstance(f, AttrComparison): value = attrs.get(f.column, schema[f.column].default_value) + if f.attr.field_path: + value = _extract_field_path(value, f.attr.field_path) return bool(f.op(value, f.other)) assert isinstance(f, AttrFilter) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 8dae996b..2c703c15 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -25,8 +25,11 @@ from tracksdata.utils._cache import cache_method from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns from tracksdata.utils._dtypes import ( + STRUCT_FIELD_SEP, AttrSchema, deserialize_attr_schema, + flatten_struct_dtype, + flatten_struct_value, polars_dtype_to_sqlalchemy_type, process_attr_key_args, serialize_attr_schema, @@ -66,10 +69,37 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _resolve_attr_filter_column( + table: type[DeclarativeBase], + attr_filter: AttrComparison, +) -> Any: + """Return the SQLAlchemy column expression for an AttrComparison. + + For struct field paths (e.g. ``NodeAttr("m").struct.field("score")``), the + field path is joined with ``STRUCT_FIELD_SEP`` to form the physical flat + column name (e.g. ``m__score``), which is a native SQL column. + + Struct attributes are stored as one physical SQL column per leaf field + (not as JSON blobs) so that filtering on a struct field is a native SQL + predicate on a leaf column rather than a server-side JSON path lookup — + the parent ``Attr``'s ``root_column`` and ``field_path`` provide the + logical-to-physical mapping. + """ + if not attr_filter.attr.field_path: + return getattr(table, str(attr_filter.column)) + + flat_col = STRUCT_FIELD_SEP.join([str(attr_filter.column), *attr_filter.attr.field_path]) + return getattr(table, flat_col) + + def _to_sql_clause(f: Filter, table: type[DeclarativeBase]) -> Any: - """Translate an AttrComparison or AttrFilter into a SQLAlchemy clause.""" + """Translate an AttrComparison or AttrFilter into a SQLAlchemy clause. + + Routes ``AttrComparison`` leaves through ``_resolve_attr_filter_column`` so + struct-field comparisons resolve to the flat physical column. + """ if isinstance(f, AttrComparison): - return f.op(getattr(table, str(f.column)), f.other) + return f.op(_resolve_attr_filter_column(table, f), f.other) assert isinstance(f, AttrFilter) if f.op == "not": @@ -241,7 +271,11 @@ def __init__( if self._node_attr_comps: node_filtered = True # filtering nodes by attributes - self._node_query = _filter_query(self._node_query, self._graph.Node, self._node_attr_comps) + self._node_query = _filter_query( + self._node_query, + self._graph.Node, + self._node_attr_comps, + ) # if both node and edge attributes are filtered # we need to select subset of edges that belong to the filtered nodes @@ -257,17 +291,29 @@ def __init__( SourceNode, self._graph.Edge.source_id == SourceNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, SourceNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + SourceNode, + self._node_attr_comps, + ) if self._include_sources or include_none: self._edge_query = self._edge_query.join( TargetNode, self._graph.Edge.target_id == TargetNode.node_id, ) - self._edge_query = _filter_query(self._edge_query, TargetNode, self._node_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + TargetNode, + self._node_attr_comps, + ) if self._edge_attr_comps: - self._edge_query = _filter_query(self._edge_query, self._graph.Edge, self._edge_attr_comps) + self._edge_query = _filter_query( + self._edge_query, + self._graph.Edge, + self._edge_attr_comps, + ) # we haven't filtered the nodes by attributes # so we only return the nodes that are in the edges @@ -357,26 +403,30 @@ def node_attrs( attr_keys=attr_keys, ) - with Session(self._graph._engine) as session: - nodes_attrs = pl.read_database( - self._graph._raw_query(query), - connection=session.connection(), - schema_overrides=self._graph._polars_schema_override(self._graph.Node), - ) + nodes_attrs = self._read_attr_dataframe(query, self._graph.Node) if attr_keys is not None: + attr_keys = list(dict.fromkeys(attr_keys)) nodes_attrs = nodes_attrs.select(attr_keys) - nodes_attrs = unpickle_bytes_columns(nodes_attrs) - nodes_attrs = self._graph._cast_array_columns(self._graph.Node, nodes_attrs) - if unpack: nodes_attrs = unpack_array_attrs(nodes_attrs) return nodes_attrs - @staticmethod + def _read_attr_dataframe(self, query: sa.Select, table: type[DeclarativeBase]) -> pl.DataFrame: + with Session(self._graph._engine) as session: + df = pl.read_database( + self._graph._raw_query(query), + connection=session.connection(), + schema_overrides=self._graph._polars_schema_override(table), + ) + + df = unpickle_bytes_columns(df) + return self._graph._cast_columns(table, df) + def _query_from_attr_keys( + self, query: sa.Select, table: type[DeclarativeBase], attr_keys: list[str] | None = None, @@ -390,14 +440,16 @@ def _query_from_attr_keys( LOG.info("Query attr_keys: %s", attr_keys) + flat_names = self._graph._physical_column_names(attr_keys, table) + if isinstance(query, sa.CompoundSelect): union_query = query.alias("u") query = sa.select( - *[getattr(union_query.c, key) for key in attr_keys], + *[getattr(union_query.c, name) for name in flat_names], ) else: query = query.with_only_columns( - *[getattr(table, key) for key in attr_keys], + *[getattr(table, name) for name in flat_names], ) LOG.info("Query after attr_keys selection:\n%s", query) @@ -417,15 +469,7 @@ def edge_attrs(self, attr_keys: list[str] | None = None, unpack: bool = False) - ], ) - with Session(self._graph._engine) as session: - edges_df = pl.read_database( - self._graph._raw_query(query), - connection=session.connection(), - schema_overrides=self._graph._polars_schema_override(self._graph.Edge), - ) - - edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._graph._cast_array_columns(self._graph.Edge, edges_df) + edges_df = self._read_attr_dataframe(query, self._graph.Edge) if unpack: edges_df = unpack_array_attrs(edges_df) @@ -470,26 +514,23 @@ def subgraph( ], ) - with Session(self._graph._engine) as session: - node_query = session.execute(node_query) - edge_query = session.execute(edge_query) - - node_map_to_root = {} - node_map_from_root = {} - rx_graph = rx.PyDiGraph() - - for row in node_query.mappings().all(): - data = dict(row) - root_node_id = data.pop(DEFAULT_ATTR_KEYS.NODE_ID) - node_id = rx_graph.add_node(data) - node_map_to_root[node_id] = root_node_id - node_map_from_root[root_node_id] = node_id - - for row in edge_query.mappings().all(): - data = dict(row) - source_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_SOURCE)] - target_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_TARGET)] - rx_graph.add_edge(source_id, target_id, data) + nodes_df = self._read_attr_dataframe(node_query, self._graph.Node) + edges_df = self._read_attr_dataframe(edge_query, self._graph.Edge) + + node_map_to_root = {} + node_map_from_root = {} + rx_graph = rx.PyDiGraph() + + for data in nodes_df.iter_rows(named=True): + root_node_id = data.pop(DEFAULT_ATTR_KEYS.NODE_ID) + node_id = rx_graph.add_node(data) + node_map_to_root[node_id] = root_node_id + node_map_from_root[root_node_id] = node_id + + for data in edges_df.iter_rows(named=True): + source_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_SOURCE)] + target_id = node_map_from_root[data.pop(DEFAULT_ATTR_KEYS.EDGE_TARGET)] + rx_graph.add_edge(source_id, target_id, data) graph = GraphView( rx_graph=rx_graph, @@ -725,27 +766,26 @@ def _attr_schemas_from_metadata( {key: deserialize_attr_schema(encoded_schema, key=key) for key, encoded_schema in encoded_schemas.items()} ) + # Compute the set of flat physical columns that belong to known struct schemas, + # so the legacy fallback below does not register them as independent logical keys. + # Without this, reloading a DB with a struct attribute "m" would re-expose + # "m__score" / "m__label" as their own top-level attributes. + known_flat_cols: set[str] = set() + for schema in schemas.values(): + if isinstance(schema.dtype, pl.Struct): + known_flat_cols.update(fc for fc, _ in flatten_struct_dtype(schema.key, schema.dtype)) + # Legacy databases may not have schema metadata for all columns. for column_name, column in table_class.__table__.columns.items(): - if column_name not in schemas: + if column_name not in schemas and column_name not in known_flat_cols: schemas[column_name] = AttrSchema( key=column_name, dtype=sqlalchemy_type_to_polars_dtype(column.type), ) - result = {} - - # return dictionary in preferred order - for source in ( - preferred_order, - table_class.__table__.columns.keys(), - schemas, - ): - for key in source: - if key in schemas: - result.setdefault(key, schemas[key]) - - return result + ordered_keys = [key for key in preferred_order if key in schemas] + ordered_keys.extend(key for key in schemas if key not in ordered_keys) + return {key: schemas[key] for key in ordered_keys} def _attr_schemas_for_table(self, table_class: type[DeclarativeBase]) -> dict[str, AttrSchema]: if table_class.__tablename__ == self.Node.__tablename__: @@ -806,38 +846,92 @@ def _restore_pickled_column_types(self, table: sa.Table) -> None: column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: + """Return polars dtype overrides for physical columns in *table_class*. + + Flat struct leaf columns are included with their native leaf dtypes. + Pickled columns are excluded here and handled in a second pass by + ``_cast_array_columns``. + """ + overrides: SchemaDict = {} schemas = self._attr_schemas_for_table(table_class) + table_cols = table_class.__table__.columns - # Return schema overrides for columns safely represented in SQL. - # Pickled columns are unpickled and casted in a second pass. - return { - key: schema.dtype - for key, schema in schemas.items() - if ( - key in table_class.__table__.columns - and not self._is_pickled_sql_type(table_class.__table__.columns[key].type) - ) - } + for key, schema in schemas.items(): + if isinstance(schema.dtype, pl.Struct): + # Emit overrides for each leaf physical column. + for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype): + if flat_col in table_cols and not self._is_pickled_sql_type(table_cols[flat_col].type): + overrides[flat_col] = leaf_dtype + elif key in table_cols and not self._is_pickled_sql_type(table_cols[key].type): + overrides[key] = schema.dtype + + return overrides - def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: + @staticmethod + def _build_struct_expr(key: str, dtype: pl.Struct) -> pl.Expr: + """Recursively build a ``pl.struct`` expression from flat leaf columns.""" + fields: list[pl.Expr] = [] + for field_name, field_dtype in dtype.to_schema().items(): + flat_col = f"{key}{STRUCT_FIELD_SEP}{field_name}" + if isinstance(field_dtype, pl.Struct): + fields.append(SQLGraph._build_struct_expr(flat_col, field_dtype).alias(field_name)) + else: + fields.append(pl.col(flat_col).alias(field_name)) + return pl.struct(fields) + + def _cast_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: + """Cast pickled columns to their target dtype and reconstruct struct columns.""" schemas = self._attr_schemas_for_table(table_class) + table_cols = table_class.__table__.columns casts: list[pl.Series] = [] + struct_keys: list[tuple[str, pl.Struct]] = [] + for key, schema in schemas.items(): - if key not in df.columns or key not in table_class.__table__.columns: + if isinstance(schema.dtype, pl.Struct): + # Cast any pickled flat leaf columns to their proper dtypes before + # reconstruction so Array/List fields have correct dtype. + for flat_col, leaf_dtype in flatten_struct_dtype(key, schema.dtype): + if flat_col not in df.columns or flat_col not in table_cols: + continue + if not self._is_pickled_sql_type(table_cols[flat_col].type): + continue + try: + casts.append(pl.Series(flat_col, df[flat_col].to_list(), dtype=leaf_dtype)) + except Exception: + continue + struct_keys.append((key, schema.dtype)) + continue + + if key not in df.columns or key not in table_cols: continue - if not self._is_pickled_sql_type(table_class.__table__.columns[key].type): + if not self._is_pickled_sql_type(table_cols[key].type): continue try: casts.append(pl.Series(key, df[key].to_list(), dtype=schema.dtype)) except Exception: - # Keep original dtype when values cannot be casted to the target schema. + # Keep original dtype when values cannot be cast to the target schema. continue if casts: df = df.with_columns(casts) + + # Reconstruct struct columns from their flat physical columns. + for key, dtype in struct_keys: + flat_cols = [fc for fc, _ in flatten_struct_dtype(key, dtype)] + present = [fc for fc in flat_cols if fc in df.columns] + if not present: + continue # struct was not part of this query; skip + missing = [fc for fc in flat_cols if fc not in df.columns] + if missing: + raise ValueError( + f"Struct attribute '{key}' is partially present in the DataFrame " + f"(missing: {missing}). Cannot reconstruct the struct column." + ) + df = df.with_columns(self._build_struct_expr(key, dtype).alias(key)).drop(flat_cols) + return df def _update_max_id_per_time(self) -> None: @@ -867,6 +961,27 @@ def filter( include_sources=include_sources, ) + def _flatten_attrs_for_write( + self, + attrs: dict[str, Any], + schemas: dict[str, AttrSchema], + ) -> dict[str, Any]: + """Expand struct-typed values into flat ``{leaf_col: value}`` pairs. + + Non-struct values are passed through unchanged. Called from every + write path (``bulk_add_nodes``, ``bulk_add_edges``, ``update_node_attrs``, + ``update_edge_attrs``) since the single-node/edge wrappers in + :class:`BaseGraph` now delegate to the bulk variants. + """ + result: dict[str, Any] = {} + for key, value in attrs.items(): + schema = schemas.get(key) + if schema is not None and isinstance(schema.dtype, pl.Struct) and isinstance(value, dict): + result.update(flatten_struct_value(key, value, schema.dtype)) + else: + result[key] = value + return result + def bulk_add_nodes( self, nodes: list[dict[str, Any]], @@ -928,7 +1043,11 @@ def bulk_add_nodes( node_ids.append(node_id) insert_rows.append({**node, DEFAULT_ATTR_KEYS.NODE_ID: node_id}) - self._chunked_sa_write(Session.bulk_insert_mappings, insert_rows, self.Node) + # Flatten struct-typed attrs into their physical leaf columns before write. + # Non-struct keys (incl. NODE_ID) pass through unchanged. + node_schemas = self._node_attr_schemas() + write_rows = [self._flatten_attrs_for_write(row, node_schemas) for row in insert_rows] + self._chunked_sa_write(Session.bulk_insert_mappings, write_rows, self.Node) emit_node_added_events(self.node_added, zip(node_ids, nodes, strict=True)) @@ -1043,9 +1162,12 @@ def bulk_add_edges( return [] return None + edge_schemas = self._edge_attr_schemas() for edge in edges: _data_numpy_to_native(edge) + edges = [self._flatten_attrs_for_write(edge, edge_schemas) for edge in edges] + if return_ids: with Session(self._engine) as session: result = session.execute(sa.insert(self.Edge).returning(self.Edge.edge_id), edges) @@ -1217,7 +1339,8 @@ def _get_neighbors( # all columns node_columns = [self.Node] else: - node_columns = [getattr(self.Node, key) for key in attr_keys] + # Expand struct logical keys to their flat physical columns. + node_columns = self._physical_cols_for_query(attr_keys, self.Node) query = session.query(getattr(self.Edge, node_key), *node_columns) query = query.join(self.Edge, getattr(self.Edge, neighbor_key) == self.Node.node_id) @@ -1235,7 +1358,7 @@ def _get_neighbors( self.Node, ) node_df = unpickle_bytes_columns(node_df) - node_df = self._cast_array_columns(self.Node, node_df) + node_df = self._cast_columns(self.Node, node_df) if single_node: if not return_attrs: @@ -1417,9 +1540,9 @@ def node_attrs( if attr_keys is not None: # making them unique attr_keys = list(dict.fromkeys(attr_keys)) - + # Expand struct logical keys to their flat physical columns. query = query.with_only_columns( - *[getattr(self.Node, key) for key in attr_keys], + *self._physical_cols_for_query(attr_keys, self.Node), ) nodes_df = pl.read_database( @@ -1428,9 +1551,9 @@ def node_attrs( schema_overrides=self._polars_schema_override(self.Node), ) nodes_df = unpickle_bytes_columns(nodes_df) - nodes_df = self._cast_array_columns(self.Node, nodes_df) + nodes_df = self._cast_columns(self.Node, nodes_df) - # indices are included by default and must be removed + # Select using logical keys (struct columns are now reconstructed). if attr_keys is not None: nodes_df = nodes_df.select([pl.col(c) for c in attr_keys]) else: @@ -1463,8 +1586,9 @@ def edge_attrs( LOG.info("Edge attribute keys: %s", attr_keys) + # Expand struct logical keys to their flat physical columns. query = query.with_only_columns( - *[getattr(self.Edge, key) for key in attr_keys], + *self._physical_cols_for_query(attr_keys, self.Edge), ) edges_df = pl.read_database( @@ -1473,7 +1597,7 @@ def edge_attrs( schema_overrides=self._polars_schema_override(self.Edge), ) edges_df = unpickle_bytes_columns(edges_df) - edges_df = self._cast_array_columns(self.Edge, edges_df) + edges_df = self._cast_columns(self.Edge, edges_df) if unpack: edges_df = unpack_array_attrs(edges_df) @@ -1488,6 +1612,41 @@ def _node_attr_schemas(self) -> dict[str, AttrSchema]: def _edge_attr_schemas(self) -> dict[str, AttrSchema]: return self.__edge_attr_schemas + @staticmethod + def _leaf_column_names(key: str, schema: AttrSchema | None) -> list[str]: + """Physical column name(s) backing a single logical attribute *key*. + + Struct attributes are stored as one physical column per leaf field, so they + expand to ``["key__a", "key__b", ...]``; every other attribute maps to the + single column ``[key]``. + """ + if schema is not None and isinstance(schema.dtype, pl.Struct): + return [flat_col for flat_col, _ in flatten_struct_dtype(key, schema.dtype)] + return [key] + + def _physical_column_names( + self, + logical_keys: Sequence[str], + table_class: type[DeclarativeBase], + ) -> list[str]: + """Expand logical attribute keys to the physical column names backing them. + + Logical keys are what the user sees (``"measurements"``); physical columns are + what actually exists in the table (``"measurements__score"``, ...). The two + diverge only for struct attributes; ``_cast_columns`` reassembles the struct + on the result DataFrame. + """ + schemas = self._attr_schemas_for_table(table_class) + return [name for key in logical_keys for name in self._leaf_column_names(key, schemas.get(key))] + + def _physical_cols_for_query( + self, + logical_keys: Sequence[str], + table_class: type[DeclarativeBase], + ) -> list[Any]: + """Like :meth:`_physical_column_names`, but returning SQLAlchemy column objects.""" + return [getattr(table_class, name) for name in self._physical_column_names(logical_keys, table_class)] + def node_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the nodes. @@ -1498,7 +1657,10 @@ def node_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include NODE_ID in the returned keys. Defaults to False. If True, NODE_ID will be included in the list. """ - keys = list(self.Node.__table__.columns.keys()) + # Read from schemas (logical keys), not __table__.columns — the latter exposes + # struct leaves (``measurements__score``) as separate keys, but the public API + # should only surface the parent ``measurements``. + keys = list(self._node_attr_schemas().keys()) if not return_ids and DEFAULT_ATTR_KEYS.NODE_ID in keys: keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) return keys @@ -1513,7 +1675,7 @@ def edge_attr_keys(self, return_ids: bool = False) -> list[str]: Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. Defaults to False. If True, these ID fields will be included in the list. """ - keys = list(self.Edge.__table__.columns.keys()) + keys = list(self._edge_attr_schemas().keys()) if not return_ids: for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: if id_key in keys: @@ -1546,13 +1708,12 @@ def _resolve_attr_keys( if len(attr_keys) == 0: raise ValueError("attr_keys must contain at least one column name") - missing = [key for key in attr_keys if key not in table_class.__table__.columns] + physical_names = self._physical_column_names(attr_keys, table_class) + + missing = [name for name in physical_names if name not in table_class.__table__.columns] if missing: raise ValueError(f"Columns {missing} do not exist on table {table_class.__tablename__}") - resolved_columns = [getattr(table_class, key) for key in attr_keys] - - if isinstance(attr_keys, str): - attr_keys = [attr_keys] + resolved_columns = [getattr(table_class, name) for name in physical_names] cols_fragment = "_".join(attr_keys) name = f"ix_{table_class.__tablename__.lower()}_{cols_fragment}" @@ -1706,28 +1867,24 @@ def _sqlalchemy_type_inference(self, default_value: Any) -> TypeEngine: else: raise ValueError(f"Unsupported default value type: {type(default_value)}") - def _add_new_column( + def _add_physical_column( self, table_class: type[DeclarativeBase], - schema: AttrSchema, + col_name: str, + sa_type: Any, + default_value: Any, ) -> None: - # Convert polars dtype to SQLAlchemy type - sa_type = polars_dtype_to_sqlalchemy_type(schema.dtype) - - # Handle special cases for default value encoding - default_value = schema.default_value + """Create a single physical SQL column and register it on the ORM class.""" if isinstance(sa_type, sa.PickleType) and default_value is not None: - # Pickle complex types for database storage default_value = blob_default(self._engine, cloudpickle.dumps(default_value)) - sa_column = sa.Column(schema.key, sa_type, default=default_value) + sa_column = sa.Column(col_name, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) identifier_preparer = self._engine.dialect.identifier_preparer quoted_table_name = identifier_preparer.format_table(table_class.__table__) quoted_column_name = identifier_preparer.quote(sa_column.name) - # Properly quote default values based on type if isinstance(default_value, str): quoted_default = f"'{default_value}'" elif default_value is None: @@ -1742,15 +1899,38 @@ def _add_new_column( ) LOG.info("add %s column statement:\n'%s'", table_class.__table__, add_column_stmt) - # create the new column in the database with Session(self._engine) as session: session.execute(add_column_stmt) session.commit() - # register the new column in the Node class - setattr(table_class, schema.key, sa_column) + setattr(table_class, col_name, sa_column) table_class.__table__.append_column(sa_column) + def _add_new_column( + self, + table_class: type[DeclarativeBase], + schema: AttrSchema, + ) -> None: + """Add a new attribute column (or flat leaf columns for structs) to *table_class*.""" + if isinstance(schema.dtype, pl.Struct): + # Expand struct into one physical column per leaf field. + flat_defaults = flatten_struct_value(schema.key, schema.default_value or {}, schema.dtype) + for flat_col, leaf_dtype in flatten_struct_dtype(schema.key, schema.dtype): + self._add_physical_column( + table_class, + flat_col, + polars_dtype_to_sqlalchemy_type(leaf_dtype), + flat_defaults.get(flat_col), + ) + return + + self._add_physical_column( + table_class, + schema.key, + polars_dtype_to_sqlalchemy_type(schema.dtype), + schema.default_value, + ) + def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: identifier_preparer = self._engine.dialect.identifier_preparer quoted_table_name = identifier_preparer.format_table(table_class.__table__) @@ -1765,6 +1945,11 @@ def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: # refresh ORM schema to reflect database changes self._define_schema(overwrite=False) + def _drop_attr_columns(self, table_class: type[DeclarativeBase], key: str, schema: AttrSchema | None) -> None: + """Drop the physical column(s) backing a logical attribute *key* (one per struct leaf).""" + for name in self._leaf_column_names(key, schema): + self._drop_column(table_class, name) + def add_node_attr_key( self, key_or_schema: str | AttrSchema, @@ -1788,7 +1973,7 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") node_schemas = self.__node_attr_schemas - self._drop_column(self.Node, key) + self._drop_attr_columns(self.Node, key, node_schemas.get(key)) node_schemas.pop(key, None) self.__node_attr_schemas = node_schemas @@ -1812,7 +1997,7 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") edge_schemas = self.__edge_attr_schemas - self._drop_column(self.Edge, key) + self._drop_attr_columns(self.Edge, key, edge_schemas.get(key)) edge_schemas.pop(key, None) self.__edge_attr_schemas = edge_schemas @@ -1851,6 +2036,8 @@ def _update_table( # Handle array values with bulk_update_mappings attrs = attrs.copy() _data_numpy_to_native(attrs) + schemas = self._attr_schemas_for_table(table_class) + attrs = self._flatten_attrs_for_write(attrs, schemas) # specialized case for scalar values - use simple bulk update if all(np.isscalar(v) for v in attrs.values()): diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 96dbc6fd..5ec4315d 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -226,6 +226,30 @@ def test_filter_nodes_by_membership(graph_backend: BaseGraph) -> None: assert set(np_members) == {node_b} +def test_filter_nodes_by_struct_field(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "name": pl.String})) + + node_a = graph_backend.add_node({"t": 0, "measurements": {"score": 1, "name": "A"}}) + node_b = graph_backend.add_node({"t": 1, "measurements": {"score": 2, "name": "B"}}) + node_c = graph_backend.add_node({"t": 2, "measurements": {"score": 1, "name": "C"}}) + + score_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("score") == 1).node_ids() + assert set(score_nodes) == {node_a, node_c} + + name_nodes = graph_backend.filter(NodeAttr("measurements").struct.field("name") == "B").node_ids() + assert set(name_nodes) == {node_b} + + measurements = graph_backend.filter(node_ids=[node_a, node_c]).node_attrs(attr_keys=["measurements"]) + assert measurements.schema["measurements"] == pl.Struct({"score": pl.Int64, "name": pl.String}) + assert {m["name"] for m in measurements["measurements"].to_list()} == {"A", "C"} + + subgraph = graph_backend.filter(NodeAttr("measurements").struct.field("score") == 1).subgraph( + node_attr_keys=["measurements"] + ) + subgraph_measurements = subgraph.node_attrs(attr_keys=["measurements"]) + assert {m["name"] for m in subgraph_measurements["measurements"].to_list()} == {"A", "C"} + + def test_time_points(graph_backend: BaseGraph) -> None: """Test retrieving time points.""" graph_backend.add_node({"t": 0}) @@ -1739,6 +1763,24 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: np.testing.assert_array_equal(stored_mask.mask, mask_data) +def test_sql_graph_struct_dtype_survives_reload(tmp_path: Path) -> None: + db_path = tmp_path / "struct_graph.db" + graph = SQLGraph("sqlite", str(db_path)) + graph.add_node_attr_key("measurements", pl.Struct({"score": pl.Int64, "label": pl.String})) + + node_id = graph.add_node({"t": 0, "measurements": {"score": 7, "label": "A"}}) + graph._engine.dispose() + + reloaded = SQLGraph("sqlite", str(db_path)) + + df = reloaded.node_attrs(attr_keys=["measurements"]) + assert df.schema["measurements"] == pl.Struct({"score": pl.Int64, "label": pl.String}) + assert df["measurements"].to_list() == [{"score": 7, "label": "A"}] + + ids = reloaded.filter(NodeAttr("measurements").struct.field("score") == 7).node_ids() + assert ids == [node_id] + + def test_sql_graph_max_id_restored_per_timepoint(tmp_path: Path) -> None: """Reloading a SQLGraph should respect existing max IDs per time point.""" db_path = tmp_path / "id_restore.db" diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 05338fa5..5c25b710 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -395,6 +395,71 @@ def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: } +STRUCT_FIELD_SEP = "__" + + +def flatten_struct_dtype( + key: str, + dtype: pl.Struct, + sep: str = STRUCT_FIELD_SEP, +) -> list[tuple[str, pl.DataType]]: + """Recursively return ``(flat_column_name, leaf_dtype)`` for all leaves of a struct. + + Parameters + ---------- + key : str + The root column name (or already-accumulated flat prefix for nested calls). + dtype : pl.Struct + The struct dtype to flatten. + sep : str + Separator between path components. Defaults to ``STRUCT_FIELD_SEP``. + + Examples + -------- + >>> flatten_struct_dtype("m", pl.Struct({"score": pl.Int64, "label": pl.String})) + [("m__score", Int64), ("m__label", String)] + """ + results: list[tuple[str, pl.DataType]] = [] + for field_name, field_dtype in dtype.to_schema().items(): + flat_key = f"{key}{sep}{field_name}" + if isinstance(field_dtype, pl.Struct): + results.extend(flatten_struct_dtype(flat_key, field_dtype, sep)) + else: + results.append((flat_key, field_dtype)) + return results + + +def flatten_struct_value( + key: str, + value: dict | None, + dtype: pl.Struct, + sep: str = STRUCT_FIELD_SEP, +) -> dict: + """Flatten a struct dict value into ``{flat_col: scalar}`` pairs. + + Parameters + ---------- + key : str + The root column name. + value : dict + The struct value to flatten (may be ``None`` or empty). + dtype : pl.Struct + The struct dtype describing the expected fields. + sep : str + Separator. Defaults to ``STRUCT_FIELD_SEP``. + """ + result: dict = {} + value = value or {} + for field_name, field_dtype in dtype.to_schema().items(): + flat_key = f"{key}{sep}{field_name}" + field_val = value.get(field_name) + if isinstance(field_dtype, pl.Struct): + result.update(flatten_struct_value(flat_key, field_val, field_dtype, sep)) + else: + result[flat_key] = field_val + return result + + def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: """ Convert a polars dtype to SQLAlchemy type.