diff --git a/src/rook/io/__init__.py b/src/rook/io/__init__.py index 9c8c139..bdd9a4e 100644 --- a/src/rook/io/__init__.py +++ b/src/rook/io/__init__.py @@ -1,5 +1,5 @@ """Dataset input and output utilities.""" -from .datasets import open_dataset +from .datasets import DatasetSource, open_dataset -__all__ = ["open_dataset"] +__all__ = ["DatasetSource", "open_dataset"] diff --git a/src/rook/io/datasets.py b/src/rook/io/datasets.py index 28d5aaf..54513fb 100644 --- a/src/rook/io/datasets.py +++ b/src/rook/io/datasets.py @@ -1,5 +1,7 @@ """Utilities for detecting and opening supported datasets.""" +from collections.abc import Iterable +from dataclasses import dataclass from pathlib import Path from urllib.parse import urlsplit @@ -13,17 +15,54 @@ ZARR_EXT = ".zarr" -def open_dataset(ds_id, file_paths, apply_fixes=True): +@dataclass(frozen=True, init=False) +class DatasetSource: + """A normalized set of paths and its optional catalog dataset id.""" + + dataset_id: str | None + paths: tuple[str, ...] + + def __init__( + self, + dataset_id: str | None, + paths: str | Path | Iterable[str | Path], + ): + """Normalize and validate source paths.""" + if isinstance(paths, (str, Path)): + paths = (str(paths),) + else: + paths = tuple(str(path) for path in paths) + + if not paths: + raise ValueError("A dataset source requires at least one path.") + if len(paths) > 1 and any( + is_kerchunk_file(path) or is_zarr_store(path) for path in paths + ): + raise ValueError("Zarr and Kerchunk sources require exactly one path.") + + if dataset_id is not None: + dataset_id = str(dataset_id) + object.__setattr__(self, "dataset_id", dataset_id) + object.__setattr__(self, "paths", paths) + + @property + def key(self): + """Return the identifier used for operation result mappings.""" + return self.dataset_id or self.paths[0] + + +def open_dataset(source: DatasetSource, *, apply_fixes=True): """Open an xarray Dataset and optionally apply rook-native fixes.""" - zarr_store = get_zarr_store(ds_id, file_paths) + zarr_store = get_zarr_store(source) if zarr_store: ds = xr.open_zarr(zarr_store, **get_zarr_open_kwargs(zarr_store)) else: - open_kwargs = get_s3_open_kwargs(ds_id, file_paths) - ds = open_xr_dataset(file_paths, **open_kwargs) + open_kwargs = get_s3_open_kwargs(source) + paths = source.paths[0] if is_kerchunk_file(source.paths[0]) else list(source.paths) + ds = open_xr_dataset(paths, **open_kwargs) - if apply_fixes and not is_kerchunk_file(ds_id) and not is_zarr_store(ds_id): - ds = apply_dataset_fixes(ds_id, ds) + if apply_fixes and source.dataset_id: + ds = apply_dataset_fixes(source.dataset_id, ds) return ds @@ -81,16 +120,10 @@ def is_zarr_store(dset): return path.endswith(ZARR_EXT) -def get_zarr_store(ds_id, file_paths): - """Return a single Zarr store from a dataset id or resolved file paths.""" - if is_zarr_store(ds_id): - return str(ds_id) - - if isinstance(file_paths, (str, Path)): - return str(file_paths) if is_zarr_store(file_paths) else None - - if file_paths and len(file_paths) == 1 and is_zarr_store(file_paths[0]): - return str(file_paths[0]) +def get_zarr_store(source: DatasetSource): + """Return the store path when a source contains exactly one Zarr store.""" + if len(source.paths) == 1 and is_zarr_store(source.paths[0]): + return source.paths[0] return None @@ -107,11 +140,9 @@ def get_zarr_open_kwargs(store): return {"storage_options": storage_options} -def get_s3_open_kwargs(ds_id, file_paths): +def get_s3_open_kwargs(source: DatasetSource): """Return opener kwargs for S3-hosted NetCDF inputs.""" - dset = ds_id - if not isinstance(dset, str) and file_paths: - dset = str(file_paths[0]) + dset = source.paths[0] if not is_s3_uri(dset) or is_kerchunk_file(dset) or is_zarr_store(dset): return {} diff --git a/src/rook/utils/ops/concat.py b/src/rook/utils/ops/concat.py index 9753c62..a1b5514 100644 --- a/src/rook/utils/ops/concat.py +++ b/src/rook/utils/ops/concat.py @@ -71,9 +71,9 @@ def _calculate(self): new_collection = collections.OrderedDict() - for dset in self.collection: - ds_id = derive_ds_id(dset) - new_collection[ds_id] = dset.file_paths + for source in self.collection: + ds_id = source.dataset_id or derive_ds_id(source.paths[0]) + new_collection[ds_id] = source.paths norm_collection = patched_normalise(new_collection) diff --git a/src/rook/utils/ops/consolidate.py b/src/rook/utils/ops/consolidate.py index 7500be0..7c5ad4a 100644 --- a/src/rook/utils/ops/consolidate.py +++ b/src/rook/utils/ops/consolidate.py @@ -10,9 +10,14 @@ from loguru import logger from rook.catalog import get_catalog -from rook.io.datasets import is_kerchunk_file, is_s3_uri, is_zarr_store +from rook.io.datasets import ( + DatasetSource, + is_kerchunk_file, + is_s3_uri, + is_zarr_store, +) -from .helpers import ordered_dict, wrap_sequence +from .helpers import wrap_sequence def to_year(time_string): @@ -88,16 +93,13 @@ def consolidate(collection, **kwargs): project = get_project_name(collection[0]) catalog = get_catalog(project) - filtered_refs = ordered_dict() + sources = [] time_param = kwargs.get("time") for dset in collection: - if is_kerchunk_file(dset) or is_zarr_store(dset): - filtered_refs[dset] = dset - - elif is_s3_uri(dset): - filtered_refs[dset] = [dset] + if is_kerchunk_file(dset) or is_zarr_store(dset) or is_s3_uri(dset): + sources.append(DatasetSource(dataset_id=None, paths=dset)) elif not catalog: file_paths = dset_to_filepaths(dset, force=True) @@ -108,7 +110,8 @@ def consolidate(collection, **kwargs): if len(file_paths) == 0: raise Exception(f"No files found in given time range for {dset}") - filtered_refs[dset] = file_paths + dataset_id = None if isinstance(dset, FileMapper) else str(dset) + sources.append(DatasetSource(dataset_id=dataset_id, paths=file_paths)) else: ds_id = derive_ds_id(dset) @@ -125,6 +128,7 @@ def consolidate(collection, **kwargs): logger.info(f"Found {len(result)} files") - filtered_refs = result.files() + for dataset_id, paths in result.files().items(): + sources.append(DatasetSource(dataset_id=dataset_id, paths=paths)) - return filtered_refs + return tuple(sources) diff --git a/src/rook/utils/ops/normalise.py b/src/rook/utils/ops/normalise.py index ab17ad6..e066379 100644 --- a/src/rook/utils/ops/normalise.py +++ b/src/rook/utils/ops/normalise.py @@ -14,9 +14,9 @@ def normalise(collection, apply_fixes=True): logger.info(f"Working on datasets: {collection}") norm_collection = ordered_dict() - for dset, file_paths in collection.items(): - ds = open_dataset(dset, file_paths, apply_fixes) - norm_collection[dset] = ds + for source in collection: + ds = open_dataset(source, apply_fixes=apply_fixes) + norm_collection[source.key] = ds return norm_collection diff --git a/tests/test_ops_consolidate.py b/tests/test_ops_consolidate.py index 6182875..e567a61 100644 --- a/tests/test_ops_consolidate.py +++ b/tests/test_ops_consolidate.py @@ -1,6 +1,7 @@ import rook.utils.ops.consolidate as consolidate from rook import config from rook.catalog.base import Result +from rook.io.datasets import DatasetSource class DummyCollection: @@ -17,9 +18,12 @@ def fail_get_catalog(_project): collection = DummyCollection(["https://example.org/refs/mydataset.json"]) result = consolidate.consolidate(collection) - assert result == { - "https://example.org/refs/mydataset.json": "https://example.org/refs/mydataset.json" - } + assert result == ( + DatasetSource( + dataset_id=None, + paths=("https://example.org/refs/mydataset.json",), + ), + ) def test_consolidate_s3_bypasses_catalog_and_mapper(monkeypatch): @@ -39,9 +43,12 @@ def fail_dset_to_filepaths(_dset, **_kwargs): collection = DummyCollection(["s3://example-bucket/path/file.nc"]) result = consolidate.consolidate(collection) - assert result == { - "s3://example-bucket/path/file.nc": ["s3://example-bucket/path/file.nc"] - } + assert result == ( + DatasetSource( + dataset_id=None, + paths=("s3://example-bucket/path/file.nc",), + ), + ) def test_consolidate_zarr_bypasses_catalog_and_mapper(monkeypatch): @@ -56,7 +63,7 @@ def fail_lookup(*_args, **_kwargs): collection = DummyCollection([store]) result = consolidate.consolidate(collection) - assert result == {store: store} + assert result == (DatasetSource(dataset_id=None, paths=(store,)),) def test_consolidate_catalog_files_can_use_s3_base_dir(monkeypatch): @@ -84,8 +91,11 @@ def search(self, collection, time): collection = DummyCollection(["c3s-cmip6.dataset"]) result = consolidate.consolidate(collection) - assert result == { - "c3s-cmip6.dataset": [ - "s3://example-bucket/data/CMIP6/ScenarioMIP/Model/file_201501-210012.nc" - ] - } + assert result == ( + DatasetSource( + dataset_id="c3s-cmip6.dataset", + paths=( + "s3://example-bucket/data/CMIP6/ScenarioMIP/Model/file_201501-210012.nc", + ), + ), + ) diff --git a/tests/test_ops_helpers.py b/tests/test_ops_helpers.py index bf10c77..e23dc3b 100644 --- a/tests/test_ops_helpers.py +++ b/tests/test_ops_helpers.py @@ -1,9 +1,39 @@ +from dataclasses import FrozenInstanceError + +import pytest import xarray as xr from rook import config import rook.io.datasets as helpers +def source(dataset_id, paths): + return helpers.DatasetSource(dataset_id=dataset_id, paths=paths) + + +def test_dataset_source_normalizes_scalar_and_list_paths(): + assert source(None, "one.nc").paths == ("one.nc",) + assert source("project.dataset", ["one.nc", "two.nc"]).paths == ( + "one.nc", + "two.nc", + ) + + +def test_dataset_source_is_immutable(): + dataset_source = source(None, "one.nc") + + with pytest.raises(FrozenInstanceError): + dataset_source.paths = ("two.nc",) + + +def test_dataset_source_rejects_multiple_zarr_or_kerchunk_paths(): + with pytest.raises(ValueError, match="exactly one path"): + source(None, ["one.zarr", "two.zarr"]) + + with pytest.raises(ValueError, match="exactly one path"): + source(None, ["one.json", "two.json"]) + + def test_open_dataset_applies_fixes(monkeypatch): calls = {"open": 0, "fix": 0} @@ -22,7 +52,9 @@ def fake_apply(ds_id, ds): monkeypatch.setattr(helpers, "apply_dataset_fixes", fake_apply) monkeypatch.setattr(helpers, "is_kerchunk_file", lambda _: False) - result = helpers.open_dataset("project.dataset", ["a.nc"], apply_fixes=True) + result = helpers.open_dataset( + source("project.dataset", ["a.nc"]), apply_fixes=True + ) assert result == "FIXED" assert calls == {"open": 1, "fix": 1} @@ -43,7 +75,9 @@ def fake_apply(ds_id, ds): monkeypatch.setattr(helpers, "apply_dataset_fixes", fake_apply) monkeypatch.setattr(helpers, "is_kerchunk_file", lambda _: False) - result = helpers.open_dataset("project.dataset", ["a.nc"], apply_fixes=False) + result = helpers.open_dataset( + source("project.dataset", ["a.nc"]), apply_fixes=False + ) assert result == "DATASET" assert calls == {"open": 1, "fix": 0} @@ -64,12 +98,27 @@ def fake_apply(ds_id, ds): monkeypatch.setattr(helpers, "apply_dataset_fixes", fake_apply) monkeypatch.setattr(helpers, "is_kerchunk_file", lambda _: True) - result = helpers.open_dataset("kerchunk.json", ["a.nc"], apply_fixes=True) + result = helpers.open_dataset( + source(None, ["kerchunk.json"]), apply_fixes=True + ) assert result == "DATASET" assert calls == {"open": 1, "fix": 0} +def test_open_dataset_skips_fixes_without_catalog_dataset_id(monkeypatch): + monkeypatch.setattr(helpers, "open_xr_dataset", lambda _paths: "DATASET") + + def fail_apply_fixes(_ds_id, _ds): + raise AssertionError("Direct paths must not trigger project fixes") + + monkeypatch.setattr(helpers, "apply_dataset_fixes", fail_apply_fixes) + + result = helpers.open_dataset(source(None, "direct.nc"), apply_fixes=True) + + assert result == "DATASET" + + def test_is_kerchunk_file_local_json(): assert helpers.is_kerchunk_file("kerchunk.json") is True @@ -112,9 +161,12 @@ def test_is_zarr_store_netcdf_path(): def test_get_zarr_store_from_catalog_file_paths(): - assert helpers.get_zarr_store( - "project.dataset", ["s3://bucket/example.zarr"] - ) == "s3://bucket/example.zarr" + assert ( + helpers.get_zarr_store( + source("project.dataset", ["s3://bucket/example.zarr"]) + ) + == "s3://bucket/example.zarr" + ) def test_open_dataset_opens_local_zarr_store(tmp_path): @@ -122,7 +174,7 @@ def test_open_dataset_opens_local_zarr_store(tmp_path): expected = xr.Dataset({"tas": ("time", [280.0, 281.0])}) expected.to_zarr(store, mode="w") - result = helpers.open_dataset(str(store), str(store), apply_fixes=False) + result = helpers.open_dataset(source(None, str(store)), apply_fixes=False) xr.testing.assert_equal(result, expected) result.close() @@ -139,7 +191,7 @@ def fail_open_zarr(*_args, **_kwargs): monkeypatch.setattr(helpers.xr, "open_zarr", fail_open_zarr) result = helpers.open_dataset( - "project.dataset", [str(path)], apply_fixes=False + source("project.dataset", [str(path)]), apply_fixes=False ) xr.testing.assert_equal(result, expected) @@ -162,8 +214,7 @@ def fake_open_zarr(store, **kwargs): ) result = helpers.open_dataset( - "s3://example-bucket/path/example.zarr", - "s3://example-bucket/path/example.zarr", + source(None, "s3://example-bucket/path/example.zarr"), apply_fixes=False, ) @@ -188,7 +239,7 @@ def fail_apply_fixes(_ds_id, _ds): monkeypatch.setattr(helpers, "apply_dataset_fixes", fail_apply_fixes) result = helpers.open_dataset( - "/data/example.zarr", "/data/example.zarr", apply_fixes=True + source(None, "/data/example.zarr"), apply_fixes=True ) assert result == "DATASET" @@ -202,7 +253,7 @@ def test_get_s3_open_kwargs_for_s3_netcdf(monkeypatch): ) kwargs = helpers.get_s3_open_kwargs( - "s3://example-bucket/path/file.nc", ["s3://example-bucket/path/file.nc"] + source(None, "s3://example-bucket/path/file.nc") ) assert kwargs == { @@ -219,7 +270,7 @@ def test_get_s3_open_kwargs_without_s3_config(monkeypatch): monkeypatch.setattr(config, "CONFIG", {}) kwargs = helpers.get_s3_open_kwargs( - "s3://example-bucket/path/file.nc", ["s3://example-bucket/path/file.nc"] + source(None, "s3://example-bucket/path/file.nc") ) assert kwargs == {} @@ -233,7 +284,7 @@ def test_get_s3_open_kwargs_skips_kerchunk(monkeypatch): ) kwargs = helpers.get_s3_open_kwargs( - "s3://example-bucket/path/ref.json", ["s3://example-bucket/path/ref.json"] + source(None, "s3://example-bucket/path/ref.json") ) assert kwargs == {} @@ -279,8 +330,7 @@ def fake_open(file_paths, **kwargs): ) _ = helpers.open_dataset( - "s3://example-bucket/path/file.nc", - ["s3://example-bucket/path/file.nc"], + source(None, "s3://example-bucket/path/file.nc"), apply_fixes=False, )