From 796d0e15dcc464dc8b32b5ce6641b13fb9d6988b Mon Sep 17 00:00:00 2001 From: Carsten Ehbrecht Date: Thu, 18 Jun 2026 21:51:05 +0200 Subject: [PATCH] added DatasetFormat and dataset Transport --- src/rook/io/__init__.py | 4 +- src/rook/io/datasets.py | 144 ++++++++++++++++++------------ src/rook/utils/ops/consolidate.py | 22 +++-- tests/test_ops_helpers.py | 110 +++++++++++++---------- 4 files changed, 164 insertions(+), 116 deletions(-) diff --git a/src/rook/io/__init__.py b/src/rook/io/__init__.py index bdd9a4e..b549b1f 100644 --- a/src/rook/io/__init__.py +++ b/src/rook/io/__init__.py @@ -1,5 +1,5 @@ """Dataset input and output utilities.""" -from .datasets import DatasetSource, open_dataset +from .datasets import DatasetFormat, DatasetSource, Transport, open_dataset -__all__ = ["DatasetSource", "open_dataset"] +__all__ = ["DatasetFormat", "DatasetSource", "Transport", "open_dataset"] diff --git a/src/rook/io/datasets.py b/src/rook/io/datasets.py index 54513fb..4481565 100644 --- a/src/rook/io/datasets.py +++ b/src/rook/io/datasets.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from dataclasses import dataclass +from enum import Enum from pathlib import Path from urllib.parse import urlsplit @@ -15,6 +16,24 @@ ZARR_EXT = ".zarr" +class DatasetFormat(Enum): + """Dataset formats supported by Rook.""" + + NETCDF = "netcdf" + ZARR = "zarr" + KERCHUNK = "kerchunk" + + +class Transport(Enum): + """Transport protocols relevant to dataset opening.""" + + FILESYSTEM = "filesystem" + HTTP = "http" + S3 = "s3" + REFERENCE = "reference" + OTHER = "other" + + @dataclass(frozen=True, init=False) class DatasetSource: """A normalized set of paths and its optional catalog dataset id.""" @@ -51,15 +70,63 @@ def key(self): return self.dataset_id or self.paths[0] +def detect_format(source: DatasetSource) -> DatasetFormat: + """Detect the data format independently of its transport protocol.""" + path = source.paths[0] + if is_zarr_store(path): + return DatasetFormat.ZARR + if is_kerchunk_file(path): + return DatasetFormat.KERCHUNK + return DatasetFormat.NETCDF + + +def detect_transport(source: DatasetSource) -> Transport: + """Detect and validate the transport shared by all source paths.""" + transports = {_detect_path_transport(path) for path in source.paths} + if len(transports) != 1: + names = ", ".join(sorted(transport.value for transport in transports)) + raise ValueError(f"Dataset paths use mixed transports: {names}.") + return transports.pop() + + +def get_storage_options(source: DatasetSource) -> dict: + """Return transport options for a dataset source.""" + if detect_transport(source) is Transport.S3: + return config.get_s3_storage_options() + return {} + + +def open_netcdf(source: DatasetSource, storage_options: dict): + """Open one or more NetCDF files through the established clisops opener.""" + kwargs = {} + if storage_options: + kwargs["backend_kwargs"] = {"storage_options": storage_options} + return open_xr_dataset(list(source.paths), **kwargs) + + +def open_zarr(source: DatasetSource, storage_options: dict): + """Open a single Zarr store.""" + kwargs = {"storage_options": storage_options} if storage_options else {} + return xr.open_zarr(source.paths[0], **kwargs) + + +def open_kerchunk(source: DatasetSource, storage_options: dict): + """Open a single Kerchunk reference through the established clisops path.""" + kwargs = {"target_options": storage_options} if storage_options else {} + return open_xr_dataset(source.paths[0], **kwargs) + + +_OPENERS = { + DatasetFormat.NETCDF: open_netcdf, + DatasetFormat.ZARR: open_zarr, + DatasetFormat.KERCHUNK: open_kerchunk, +} + + def open_dataset(source: DatasetSource, *, apply_fixes=True): """Open an xarray Dataset and optionally apply rook-native fixes.""" - zarr_store = get_zarr_store(source) - if zarr_store: - ds = xr.open_zarr(zarr_store, **get_zarr_open_kwargs(zarr_store)) - else: - open_kwargs = get_s3_open_kwargs(source) - paths = source.paths[0] if is_kerchunk_file(source.paths[0]) else list(source.paths) - ds = open_xr_dataset(paths, **open_kwargs) + opener = _OPENERS[detect_format(source)] + ds = opener(source, get_storage_options(source)) if apply_fixes and source.dataset_id: ds = apply_dataset_fixes(source.dataset_id, ds) @@ -89,21 +156,6 @@ def is_kerchunk_file(dset): return path.endswith(KERCHUNK_EXTS) -def is_s3_uri(dset): - """Return True when the input points to an S3 object URI.""" - if isinstance(dset, Path): - dset = str(dset) - - if not isinstance(dset, str): - return False - - value = dset.strip() - if not value: - return False - - return value.lower().startswith("s3://") - - def is_zarr_store(dset): """Return True when the input looks like a Zarr store path.""" if isinstance(dset, Path): @@ -120,40 +172,14 @@ def is_zarr_store(dset): return path.endswith(ZARR_EXT) -def get_zarr_store(source: DatasetSource): - """Return the store path when a source contains exactly one Zarr store.""" - if len(source.paths) == 1 and is_zarr_store(source.paths[0]): - return source.paths[0] - - return None - - -def get_zarr_open_kwargs(store): - """Return xarray opener kwargs for a Zarr store.""" - if not is_s3_uri(store): - return {} - - storage_options = get_s3_storage_options() - if not storage_options: - return {} - - return {"storage_options": storage_options} - - -def get_s3_open_kwargs(source: DatasetSource): - """Return opener kwargs for S3-hosted NetCDF inputs.""" - dset = source.paths[0] - - if not is_s3_uri(dset) or is_kerchunk_file(dset) or is_zarr_store(dset): - return {} - - storage_options = get_s3_storage_options() - if not storage_options: - return {} - - return {"backend_kwargs": {"storage_options": storage_options}} - - -def get_s3_storage_options(): - """Return shared S3 transport options from central configuration.""" - return config.get_s3_storage_options() +def _detect_path_transport(path: str) -> Transport: + scheme = urlsplit(path.strip()).scheme.lower() + if scheme in {"", "file"}: + return Transport.FILESYSTEM + if scheme in {"http", "https"}: + return Transport.HTTP + if scheme == "s3": + return Transport.S3 + if scheme == "reference": + return Transport.REFERENCE + return Transport.OTHER diff --git a/src/rook/utils/ops/consolidate.py b/src/rook/utils/ops/consolidate.py index 7c5ad4a..06016da 100644 --- a/src/rook/utils/ops/consolidate.py +++ b/src/rook/utils/ops/consolidate.py @@ -11,15 +11,25 @@ from rook.catalog import get_catalog from rook.io.datasets import ( + DatasetFormat, DatasetSource, - is_kerchunk_file, - is_s3_uri, - is_zarr_store, + Transport, + detect_format, + detect_transport, ) from .helpers import wrap_sequence +def _bypasses_catalog(value): + """Return whether a direct source should skip project resolution.""" + source = DatasetSource(dataset_id=None, paths=value) + return ( + detect_format(source) is not DatasetFormat.NETCDF + or detect_transport(source) is Transport.S3 + ) + + def to_year(time_string): """Return the year in a time string as an integer.""" return int(time_string.split("-")[0]) @@ -86,9 +96,7 @@ def consolidate(collection, **kwargs): if ( not isinstance(collection[0], FileMapper) - and not is_kerchunk_file(collection[0]) - and not is_s3_uri(collection[0]) - and not is_zarr_store(collection[0]) + and not _bypasses_catalog(collection[0]) ): project = get_project_name(collection[0]) catalog = get_catalog(project) @@ -98,7 +106,7 @@ def consolidate(collection, **kwargs): time_param = kwargs.get("time") for dset in collection: - if is_kerchunk_file(dset) or is_zarr_store(dset) or is_s3_uri(dset): + if not isinstance(dset, FileMapper) and _bypasses_catalog(dset): sources.append(DatasetSource(dataset_id=None, paths=dset)) elif not catalog: diff --git a/tests/test_ops_helpers.py b/tests/test_ops_helpers.py index e23dc3b..d1219d8 100644 --- a/tests/test_ops_helpers.py +++ b/tests/test_ops_helpers.py @@ -140,12 +140,14 @@ def test_is_kerchunk_file_non_kerchunk_path(): assert helpers.is_kerchunk_file("/data/file.nc") is False -def test_is_s3_uri_true(): - assert helpers.is_s3_uri("s3://my-bucket/path/file.nc") is True +def test_detect_format_netcdf(): + assert helpers.detect_format(source(None, "file.nc")) is helpers.DatasetFormat.NETCDF -def test_is_s3_uri_false_for_https(): - assert helpers.is_s3_uri("https://example.org/file.nc") is False +def test_detect_format_kerchunk_url_with_query(): + dataset_source = source(None, "https://example.org/ref.json?token=abc") + + assert helpers.detect_format(dataset_source) is helpers.DatasetFormat.KERCHUNK def test_is_zarr_store_local_path(): @@ -160,15 +162,32 @@ def test_is_zarr_store_netcdf_path(): assert helpers.is_zarr_store("s3://bucket/example.nc") is False -def test_get_zarr_store_from_catalog_file_paths(): +def test_detect_format_zarr_from_catalog_paths(): + assert helpers.detect_format( + source("project.dataset", ["s3://bucket/example.zarr"]) + ) is helpers.DatasetFormat.ZARR + + +def test_detect_transport_is_independent_of_format(): assert ( - helpers.get_zarr_store( - source("project.dataset", ["s3://bucket/example.zarr"]) - ) - == "s3://bucket/example.zarr" + helpers.detect_transport(source(None, "s3://bucket/file.nc")) + is helpers.Transport.S3 + ) + assert ( + helpers.detect_transport(source(None, "s3://bucket/example.zarr")) + is helpers.Transport.S3 + ) + assert ( + helpers.detect_transport(source(None, "https://example.org/ref.json")) + is helpers.Transport.HTTP ) +def test_detect_transport_rejects_mixed_transports(): + with pytest.raises(ValueError, match="mixed transports"): + helpers.detect_transport(source(None, ["/data/one.nc", "s3://bucket/two.nc"])) + + def test_open_dataset_opens_local_zarr_store(tmp_path): store = tmp_path / "example.zarr" expected = xr.Dataset({"tas": ("time", [280.0, 281.0])}) @@ -245,74 +264,69 @@ def fail_apply_fixes(_ds_id, _ds): assert result == "DATASET" -def test_get_s3_open_kwargs_for_s3_netcdf(monkeypatch): +def test_get_storage_options_for_s3_source(monkeypatch): monkeypatch.setattr( config, "CONFIG", {"s3": {"anon": "true", "endpoint_url": "https://s3.example.org"}}, ) - kwargs = helpers.get_s3_open_kwargs( - source(None, "s3://example-bucket/path/file.nc") - ) + options = helpers.get_storage_options(source(None, "s3://bucket/file.nc")) - assert kwargs == { - "backend_kwargs": { - "storage_options": { - "anon": True, - "client_kwargs": {"endpoint_url": "https://s3.example.org"}, - } - } + assert options == { + "anon": True, + "client_kwargs": {"endpoint_url": "https://s3.example.org"}, } -def test_get_s3_open_kwargs_without_s3_config(monkeypatch): +def test_get_storage_options_without_s3_config(monkeypatch): monkeypatch.setattr(config, "CONFIG", {}) - kwargs = helpers.get_s3_open_kwargs( - source(None, "s3://example-bucket/path/file.nc") - ) + options = helpers.get_storage_options(source(None, "s3://bucket/file.nc")) - assert kwargs == {} + assert options == {} -def test_get_s3_open_kwargs_skips_kerchunk(monkeypatch): +def test_get_storage_options_does_not_depend_on_format(monkeypatch): monkeypatch.setattr( config, "CONFIG", {"s3": {"anon": "true", "endpoint_url": "https://s3.example.org"}}, ) - kwargs = helpers.get_s3_open_kwargs( - source(None, "s3://example-bucket/path/ref.json") - ) + options = helpers.get_storage_options(source(None, "s3://bucket/ref.json")) - assert kwargs == {} + assert options["anon"] is True -def test_get_s3_storage_options_merges_client_kwargs(monkeypatch): - monkeypatch.setattr( - config, - "CONFIG", - { - "s3": { - "storage_options_json": '{"anon": true, "client_kwargs": {"region_name": "eu-west-1"}}', - "client_kwargs_json": '{"use_ssl": false}', - "endpoint_url": "https://s3.example.org", - } - }, +def test_open_dataset_passes_s3_options_to_kerchunk(monkeypatch): + calls = {} + + def fake_open(path, **kwargs): + calls["path"] = path + calls["kwargs"] = kwargs + return "DATASET" + + monkeypatch.setattr(helpers, "open_xr_dataset", fake_open) + monkeypatch.setattr(config, "CONFIG", {"s3": {"anon": "true"}}) + + result = helpers.open_dataset( + source(None, "s3://bucket/reference.json"), apply_fixes=False ) - assert helpers.get_s3_storage_options() == { - "anon": True, - "client_kwargs": { - "region_name": "eu-west-1", - "use_ssl": False, - "endpoint_url": "https://s3.example.org", - }, + assert result == "DATASET" + assert calls == { + "path": "s3://bucket/reference.json", + "kwargs": {"target_options": {"anon": True}}, } +def test_get_storage_options_skips_local_files(monkeypatch): + monkeypatch.setattr(config, "get_s3_storage_options", lambda: pytest.fail()) + + assert helpers.get_storage_options(source(None, "/data/file.nc")) == {} + + def test_open_dataset_passes_s3_backend_kwargs(monkeypatch): calls = {"open_kwargs": None}