From 22feee50fa5223b5cf3405a3a2257e841dfaa0af Mon Sep 17 00:00:00 2001
From: Mark Keller <7525285+keller-mark@users.noreply.github.com>
Date: Thu, 3 Jul 2025 16:31:12 -0400
Subject: [PATCH 1/6] Geometry encoding parameter for shapes

---
 src/spatialdata/_core/spatialdata.py | 5 ++++-
 src/spatialdata/_io/io_shapes.py     | 3 ++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index 48f6386ca..f5a2bec2d 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1179,6 +1179,7 @@ def write(
         overwrite: bool = False,
         consolidate_metadata: bool = True,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        shapes_geometry_encoding: 'WKB' | 'geoarrow' = 'WKB',
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1223,6 +1224,7 @@ def write(
                     element_name=element_name,
                     overwrite=False,
                     format=format,
+                    shapes_geometry_encoding=shapes_geometry_encoding,
                 )

         if self.path != file_path:
@@ -1241,6 +1243,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        shapes_geometry_encoding: 'WKB' | 'geoarrow' = 'WKB',
     ) -> None:
         if not isinstance(zarr_container_path, Path):
             raise ValueError(
@@ -1266,7 +1269,7 @@ def _write_element(
         elif element_type == "points":
             write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"])
         elif element_type == "shapes":
-            write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"])
+            write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"], geometry_encoding=shapes_geometry_encoding)
         elif element_type == "tables":
             write_table(table=element, group=element_type_group, name=element_name, format=parsed["tables"])
         else:
diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index c32ce1f34..c0f74e80d 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -68,6 +68,7 @@ def write_shapes(
     name: str,
     group_type: str = "ngff:shapes",
     format: Format = CurrentShapesFormat(),
+    geometry_encoding: 'WKB' | 'geoarrow' = 'WKB',
 ) -> None:
     import numcodecs

@@ -94,7 +95,7 @@ def write_shapes(
     elif isinstance(format, ShapesFormatV02):
         path = Path(shapes_group._store.path) / shapes_group.path / "shapes.parquet"
-        shapes.to_parquet(path)
+        shapes.to_parquet(path, geometry_encoding=geometry_encoding)
         attrs = format.attrs_to_dict(shapes.attrs)
         attrs["version"] = format.spatialdata_format_version


From 74690e4ff2b1d5644a1e43010324db8f94d01777 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 3 Jul 2025 20:36:07 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spatialdata/_core/spatialdata.py | 12 +++++++++---
 src/spatialdata/_io/io_shapes.py     |  2 +-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index f5a2bec2d..654a2908a 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1179,7 +1179,7 @@ def write(
         overwrite: bool = False,
         consolidate_metadata: bool = True,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
-        shapes_geometry_encoding: 'WKB' | 'geoarrow' = 'WKB',
+        shapes_geometry_encoding: "WKB" | "geoarrow" = "WKB",
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1243,7 +1243,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
-        shapes_geometry_encoding: 'WKB' | 'geoarrow' = 'WKB',
+        shapes_geometry_encoding: "WKB" | "geoarrow" = "WKB",
     ) -> None:
         if not isinstance(zarr_container_path, Path):
             raise ValueError(
@@ -1269,7 +1269,13 @@ def _write_element(
         elif element_type == "points":
             write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"])
         elif element_type == "shapes":
-            write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"], geometry_encoding=shapes_geometry_encoding)
+            write_shapes(
+                shapes=element,
+                group=element_type_group,
+                name=element_name,
+                format=parsed["shapes"],
+                geometry_encoding=shapes_geometry_encoding,
+            )
         elif element_type == "tables":
             write_table(table=element, group=element_type_group, name=element_name, format=parsed["tables"])
         else:
diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index c0f74e80d..93dec6c79 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -68,7 +68,7 @@ def write_shapes(
     name: str,
     group_type: str = "ngff:shapes",
     format: Format = CurrentShapesFormat(),
-    geometry_encoding: 'WKB' | 'geoarrow' = 'WKB',
+    geometry_encoding: "WKB" | "geoarrow" = "WKB",
 ) -> None:
     import numcodecs

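A minimal usage sketch for the parameter wired in by the two patches above, assuming the series is applied; the blobs() example dataset and the store paths are illustrative and not part of the patch. The parameter selects how shapes geometries are serialized when GeoPandas writes the GeoParquet file:

    # Illustrative only: any SpatialData object containing shapes elements works.
    import spatialdata as sd

    sdata = sd.datasets.blobs()

    # Default: geometries are stored in the GeoParquet file as WKB.
    sdata.write("blobs_wkb.zarr", shapes_geometry_encoding="WKB")

    # Opt in to the columnar geoarrow encoding instead (requires a geopandas
    # version whose GeoDataFrame.to_parquet supports the geometry_encoding keyword).
    sdata.write("blobs_geoarrow.zarr", shapes_geometry_encoding="geoarrow")

Note that the quoted union 'WKB' | 'geoarrow' used as the annotation here is not a valid type expression; the follow-up patches below replace it with typing.Literal.
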
From fc7d5cee091f1835a61d49ceb5796a8cbf56eaf4 Mon Sep 17 00:00:00 2001
From: Mark Keller <7525285+keller-mark@users.noreply.github.com>
Date: Fri, 18 Jul 2025 14:41:34 -0400
Subject: [PATCH 3/6] Update spatialdata.py

---
 src/spatialdata/_core/spatialdata.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index 654a2908a..77844fcdb 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1179,7 +1179,7 @@ def write(
         overwrite: bool = False,
         consolidate_metadata: bool = True,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
-        shapes_geometry_encoding: "WKB" | "geoarrow" = "WKB",
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1243,7 +1243,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
-        shapes_geometry_encoding: "WKB" | "geoarrow" = "WKB",
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
     ) -> None:
         if not isinstance(zarr_container_path, Path):
             raise ValueError(

From 76dd54e3960509c79787dd454f1392ecc35469cd Mon Sep 17 00:00:00 2001
From: Mark Keller <7525285+keller-mark@users.noreply.github.com>
Date: Fri, 18 Jul 2025 14:42:05 -0400
Subject: [PATCH 4/6] Update io_shapes.py

---
 src/spatialdata/_io/io_shapes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index 93dec6c79..2cdae06bc 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -68,7 +68,7 @@ def write_shapes(
     name: str,
     group_type: str = "ngff:shapes",
     format: Format = CurrentShapesFormat(),
-    geometry_encoding: "WKB" | "geoarrow" = "WKB",
+    geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
 ) -> None:
     import numcodecs


From 8680540cfecd2913f760ab0b520e576fa688646c Mon Sep 17 00:00:00 2001
From: Mark Keller <7525285+keller-mark@users.noreply.github.com>
Date: Fri, 18 Jul 2025 14:43:55 -0400
Subject: [PATCH 5/6] Update io_shapes.py

---
 src/spatialdata/_io/io_shapes.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index 2cdae06bc..fc57b436c 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -1,5 +1,6 @@
 from collections.abc import MutableMapping
 from pathlib import Path
+from typing import Literal

 import numpy as np
 import zarr

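The three patches above replace the quoted string union with typing.Literal, which is what actually lets a type checker constrain the accepted values. A small self-contained sketch of the difference (the names below are illustrative, not from the patch):

    from typing import Literal

    GeometryEncoding = Literal["WKB", "geoarrow"]


    def check_encoding(encoding: GeometryEncoding = "WKB") -> str:
        # A type checker (mypy/pyright) rejects check_encoding("wkb") or any other
        # value outside the two literals; the earlier 'WKB' | 'geoarrow' spelling is
        # not a valid type expression and cannot express this constraint.
        return encoding


    check_encoding("geoarrow")  # accepted
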
From 0637dac69e12c41819f120782aa98f38818d9b70 Mon Sep 17 00:00:00 2001
From: Luca Marconato
Date: Mon, 5 Jan 2026 17:01:02 +0100
Subject: [PATCH 6/6] add setting for geometry_encoding; add tests

---
 .gitignore                           |  1 +
 src/spatialdata/__init__.py          |  2 +
 src/spatialdata/_core/spatialdata.py | 10 ++--
 src/spatialdata/_io/io_shapes.py     |  9 +++-
 src/spatialdata/config.py            | 32 +++++++++--
 src/spatialdata/models/models.py     |  8 +--
 tests/io/test_readwrite.py           | 80 ++++++++++++++++++++++++++--
 7 files changed, 124 insertions(+), 18 deletions(-)

diff --git a/.gitignore b/.gitignore
index dafbf6647..9a6001e35 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,3 +52,4 @@ node_modules/

 .mypy_cache
 .ruff_cache
+uv.lock
diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py
index 2fb483505..cc94d6f62 100644
--- a/src/spatialdata/__init__.py
+++ b/src/spatialdata/__init__.py
@@ -40,6 +40,7 @@
     "deepcopy",
     "sanitize_table",
     "sanitize_name",
+    "settings",
 ]

 from spatialdata import dataloader, datasets, models, transformations
@@ -70,3 +71,4 @@
 from spatialdata._io.format import SpatialDataFormatType
 from spatialdata._io.io_zarr import read_zarr
 from spatialdata._utils import get_pyramid_levels, unpad_raster
+from spatialdata.config import settings
diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index 939185a42..b251548ae 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1110,7 +1110,7 @@ def write(
         consolidate_metadata: bool = True,
         update_sdata_path: bool = True,
         sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
-        shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1157,7 +1157,7 @@ def write(
             `spatialdata._io.format.py`.
         shapes_geometry_encoding
             Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet`
-            for details.
+            for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
         """
         from spatialdata._io._utils import _resolve_zarr_store
         from spatialdata._io.format import _parse_formats
@@ -1200,7 +1200,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         parsed_formats: dict[str, SpatialDataFormatType] | None = None,
-        shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
     ) -> None:
         from spatialdata._io.io_zarr import _get_groups_for_element

@@ -1270,7 +1270,7 @@ def write_element(
         element_name: str | list[str],
         overwrite: bool = False,
         sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
-        shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
     ) -> None:
         """
         Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1288,7 +1288,7 @@ def write_element(
             `SpatialData.write()`.
         shapes_geometry_encoding
             Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet`
-            for details.
+            for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.

         Notes
         -----
diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index f354068a6..65cb099a0 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -70,7 +70,7 @@ def write_shapes(
     group: zarr.Group,
     group_type: str = "ngff:shapes",
     element_format: Format = CurrentShapesFormat(),
-    geometry_encoding: Literal["WKB", "geoarrow"] = "WKB",
+    geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
 ) -> None:
     """Write shapes to spatialdata zarr store.

@@ -89,8 +89,13 @@ def write_shapes(
         The format of the shapes element used to store it.
     geometry_encoding
         Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for
-        details.
+        details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
     """
+    from spatialdata.config import settings
+
+    if geometry_encoding is None:
+        geometry_encoding = settings.shapes_geometry_encoding
+
     axes = get_axes_names(shapes)
     transformations = _get_transformations(shapes)
     if transformations is None:
diff --git a/src/spatialdata/config.py b/src/spatialdata/config.py
index 309f20e4d..dab848b35 100644
--- a/src/spatialdata/config.py
+++ b/src/spatialdata/config.py
@@ -1,4 +1,28 @@
-# chunk sizes bigger than this value (bytes) can trigger a compression error
-# https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
-# so if we detect this during parsing/validation we raise a warning
-LARGE_CHUNK_THRESHOLD_BYTES = 2147483647
+from dataclasses import dataclass
+from typing import Literal
+
+
+@dataclass
+class Settings:
+    """Global settings for spatialdata.
+
+    Attributes
+    ----------
+    shapes_geometry_encoding
+        Default geometry encoding for GeoParquet files when writing shapes.
+        Can be "WKB" (Well-Known Binary) or "geoarrow".
+        See :meth:`geopandas.GeoDataFrame.to_parquet` for details.
+    large_chunk_threshold_bytes
+        Chunk sizes bigger than this value (bytes) can trigger a compression error.
+        See https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
+        If detected during parsing/validation, a warning is raised.
+    """
+
+    shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB"
+    large_chunk_threshold_bytes: int = 2147483647
+
+
+settings = Settings()
+
+# Backwards compatibility alias
+LARGE_CHUNK_THRESHOLD_BYTES = settings.large_chunk_threshold_bytes
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 92acbec44..ddda2a612 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -35,7 +35,7 @@
 from spatialdata._logging import logger
 from spatialdata._types import ArrayLike
 from spatialdata._utils import _check_match_length_channels_c_dim
-from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES
+from spatialdata.config import settings
 from spatialdata.models import C, X, Y, Z, get_axes_names
 from spatialdata.models._utils import (
     DEFAULT_COORDINATE_SYSTEM,
@@ -315,9 +315,9 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
             return
         n_elems = np.array(list(max_per_dimension.values())).prod().item()
         usage = n_elems * data.dtype.itemsize
-        if usage > LARGE_CHUNK_THRESHOLD_BYTES:
+        if usage > settings.large_chunk_threshold_bytes:
             warnings.warn(
-                f"Detected chunks larger than: {usage} > {LARGE_CHUNK_THRESHOLD_BYTES} bytes. "
+                f"Detected chunks larger than: {usage} > {settings.large_chunk_threshold_bytes} bytes. "
                 "This can lead to low "
                 "performance and memory issues downstream, and sometimes cause compression errors when writing "
                 "(https://github.com/scverse/spatialdata/issues/812#issuecomment-2575983527). Please consider using"
@@ -327,7 +327,7 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
                 "2) Multiscale representations can be achieved by using the `scale_factors` argument in the "
                 "`parse()` function.\n"
                 "You can suppress this warning by increasing the value of "
-                "`spatialdata.config.LARGE_CHUNK_THRESHOLD_BYTES`.",
+                "`spatialdata.settings.large_chunk_threshold_bytes`.",
                 UserWarning,
                 stacklevel=2,
             )
diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index 11855a222..7ecd74205 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -3,17 +3,21 @@
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal

 import dask.dataframe as dd
 import numpy as np
+import pandas as pd
+import pyarrow.parquet as pq
 import pytest
 import zarr
 from anndata import AnnData
 from numpy.random import default_rng
+from shapely import MultiPolygon, Polygon
 from upath import UPath
 from zarr.errors import GroupNotFoundError

+import spatialdata.config
 from spatialdata import SpatialData, deepcopy, read_zarr
 from spatialdata._core.validation import ValidationError
 from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files
@@ -74,20 +78,90 @@ def test_labels(
         sdata = SpatialData.read(tmpdir)
         assert_spatial_data_objects_are_identical(labels, sdata)

+    @pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
     def test_shapes(
         self,
         tmp_path: str,
         shapes: SpatialData,
         sdata_container_format: SpatialDataContainerFormatType,
+        geometry_encoding: Literal["WKB", "geoarrow"],
     ) -> None:
         tmpdir = Path(tmp_path) / "tmp.zarr"

         # check the index is correctly written and then read
         shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1)
-        shapes.write(tmpdir, sdata_formats=sdata_container_format)
+
+        # add a mixed Polygon + MultiPolygon element
+        shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])
+
+        shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding)
         sdata = SpatialData.read(tmpdir)
-        assert_spatial_data_objects_are_identical(shapes, sdata)
+
+        if geometry_encoding == "WKB":
+            assert_spatial_data_objects_are_identical(shapes, sdata)
+        else:
+            # convert each Polygon to a MultiPolygon
+            mixed_multipolygon = shapes["mixed"].assign(
+                geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g)
+            )
+            assert sdata["mixed"].equals(mixed_multipolygon)
+            assert not sdata["mixed"].equals(shapes["mixed"])
+
+            del shapes["mixed"]
+            del sdata["mixed"]
+            assert_spatial_data_objects_are_identical(shapes, sdata)
+
+    @pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
+    def test_shapes_geometry_encoding_write_element(
+        self,
+        tmp_path: str,
+        shapes: SpatialData,
+        sdata_container_format: SpatialDataContainerFormatType,
+        geometry_encoding: Literal["WKB", "geoarrow"],
+    ) -> None:
+        """Test shapes geometry encoding with write_element() and global settings."""
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+
+        # First write an empty SpatialData to create the zarr store
+        empty_sdata = SpatialData()
+        empty_sdata.write(tmpdir, sdata_formats=sdata_container_format)
+
+        shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])
+
+        # Add shapes to the empty sdata
+        for shape_name in shapes.shapes:
+            empty_sdata[shape_name] = shapes[shape_name]
+
+        # Store original setting and set global encoding
+        original_encoding = spatialdata.config.settings.shapes_geometry_encoding
+        try:
+            spatialdata.config.settings.shapes_geometry_encoding = geometry_encoding
+
+            # Write each shape element - should use global setting
+            for shape_name in shapes.shapes:
+                empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format)
+
+                # Verify the encoding metadata in the parquet file
+                parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet"
+                with pq.ParquetFile(parquet_file) as pf:
+                    md = pf.metadata
+                    d = json.loads(md.metadata[b"geo"].decode("utf-8"))
+                    found_encoding = d["columns"]["geometry"]["encoding"]
+                    if geometry_encoding == "WKB":
+                        expected_encoding = "WKB"
+                    elif shape_name == "circles":
+                        expected_encoding = "point"
+                    elif shape_name == "poly":
+                        expected_encoding = "polygon"
+                    elif shape_name in ["multipoly", "mixed"]:
+                        expected_encoding = "multipolygon"
+                    else:
+                        raise ValueError(
+                            f"Uncovered case for shape_name: {shape_name}, found encoding: {found_encoding}."
+                        )
+                    assert found_encoding == expected_encoding
+        finally:
+            spatialdata.config.settings.shapes_geometry_encoding = original_encoding

     def test_points(
         self,
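
A sketch of the settings-based workflow that this last patch enables, assuming the series is applied. The blobs() dataset, the element name, and the store path are illustrative; the shapes/<name>/shapes.parquet layout is the one exercised by the test above:

    import json

    import pyarrow.parquet as pq
    import spatialdata as sd

    # Set the global default once; subsequent writes pick it up because
    # shapes_geometry_encoding now defaults to None and falls back to the setting.
    sd.settings.shapes_geometry_encoding = "geoarrow"

    sdata = sd.datasets.blobs()  # illustrative dataset
    sdata.write("blobs.zarr")    # no per-call argument needed

    # The chosen encoding is recorded in the GeoParquet "geo" metadata of each shapes element.
    pf = pq.ParquetFile("blobs.zarr/shapes/blobs_polygons/shapes.parquet")  # illustrative element name
    geo = json.loads(pf.metadata.metadata[b"geo"].decode("utf-8"))
    print(geo["columns"]["geometry"]["encoding"])  # "polygon" here; "WKB" with the default setting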