Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/rook/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Dataset input and output utilities."""

from .datasets import open_dataset
from .datasets import DatasetSource, open_dataset

__all__ = ["open_dataset"]
__all__ = ["DatasetSource", "open_dataset"]
71 changes: 51 additions & 20 deletions src/rook/io/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for detecting and opening supported datasets."""

from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlsplit

Expand All @@ -13,17 +15,54 @@
ZARR_EXT = ".zarr"


def open_dataset(ds_id, file_paths, apply_fixes=True):
@dataclass(frozen=True, init=False)
class DatasetSource:
"""A normalized set of paths and its optional catalog dataset id."""

dataset_id: str | None
paths: tuple[str, ...]

def __init__(
self,
dataset_id: str | None,
paths: str | Path | Iterable[str | Path],
):
"""Normalize and validate source paths."""
if isinstance(paths, (str, Path)):
paths = (str(paths),)
else:
paths = tuple(str(path) for path in paths)

if not paths:
raise ValueError("A dataset source requires at least one path.")
if len(paths) > 1 and any(
is_kerchunk_file(path) or is_zarr_store(path) for path in paths
):
raise ValueError("Zarr and Kerchunk sources require exactly one path.")

if dataset_id is not None:
dataset_id = str(dataset_id)
object.__setattr__(self, "dataset_id", dataset_id)
object.__setattr__(self, "paths", paths)

@property
def key(self):
"""Return the identifier used for operation result mappings."""
return self.dataset_id or self.paths[0]


def open_dataset(source: DatasetSource, *, apply_fixes=True):
"""Open an xarray Dataset and optionally apply rook-native fixes."""
zarr_store = get_zarr_store(ds_id, file_paths)
zarr_store = get_zarr_store(source)
if zarr_store:
ds = xr.open_zarr(zarr_store, **get_zarr_open_kwargs(zarr_store))
else:
open_kwargs = get_s3_open_kwargs(ds_id, file_paths)
ds = open_xr_dataset(file_paths, **open_kwargs)
open_kwargs = get_s3_open_kwargs(source)
paths = source.paths[0] if is_kerchunk_file(source.paths[0]) else list(source.paths)
ds = open_xr_dataset(paths, **open_kwargs)

if apply_fixes and not is_kerchunk_file(ds_id) and not is_zarr_store(ds_id):
ds = apply_dataset_fixes(ds_id, ds)
if apply_fixes and source.dataset_id:
ds = apply_dataset_fixes(source.dataset_id, ds)

return ds

Expand Down Expand Up @@ -81,16 +120,10 @@ def is_zarr_store(dset):
return path.endswith(ZARR_EXT)


def get_zarr_store(ds_id, file_paths):
"""Return a single Zarr store from a dataset id or resolved file paths."""
if is_zarr_store(ds_id):
return str(ds_id)

if isinstance(file_paths, (str, Path)):
return str(file_paths) if is_zarr_store(file_paths) else None

if file_paths and len(file_paths) == 1 and is_zarr_store(file_paths[0]):
return str(file_paths[0])
def get_zarr_store(source: DatasetSource):
"""Return the store path when a source contains exactly one Zarr store."""
if len(source.paths) == 1 and is_zarr_store(source.paths[0]):
return source.paths[0]

return None

Expand All @@ -107,11 +140,9 @@ def get_zarr_open_kwargs(store):
return {"storage_options": storage_options}


def get_s3_open_kwargs(ds_id, file_paths):
def get_s3_open_kwargs(source: DatasetSource):
"""Return opener kwargs for S3-hosted NetCDF inputs."""
dset = ds_id
if not isinstance(dset, str) and file_paths:
dset = str(file_paths[0])
dset = source.paths[0]

if not is_s3_uri(dset) or is_kerchunk_file(dset) or is_zarr_store(dset):
return {}
Expand Down
6 changes: 3 additions & 3 deletions src/rook/utils/ops/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def _calculate(self):

new_collection = collections.OrderedDict()

for dset in self.collection:
ds_id = derive_ds_id(dset)
new_collection[ds_id] = dset.file_paths
for source in self.collection:
ds_id = source.dataset_id or derive_ds_id(source.paths[0])
new_collection[ds_id] = source.paths

norm_collection = patched_normalise(new_collection)

Expand Down
26 changes: 15 additions & 11 deletions src/rook/utils/ops/consolidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
from loguru import logger

from rook.catalog import get_catalog
from rook.io.datasets import is_kerchunk_file, is_s3_uri, is_zarr_store
from rook.io.datasets import (
DatasetSource,
is_kerchunk_file,
is_s3_uri,
is_zarr_store,
)

from .helpers import ordered_dict, wrap_sequence
from .helpers import wrap_sequence


def to_year(time_string):
Expand Down Expand Up @@ -88,16 +93,13 @@ def consolidate(collection, **kwargs):
project = get_project_name(collection[0])
catalog = get_catalog(project)

filtered_refs = ordered_dict()
sources = []

time_param = kwargs.get("time")

for dset in collection:
if is_kerchunk_file(dset) or is_zarr_store(dset):
filtered_refs[dset] = dset

elif is_s3_uri(dset):
filtered_refs[dset] = [dset]
if is_kerchunk_file(dset) or is_zarr_store(dset) or is_s3_uri(dset):
sources.append(DatasetSource(dataset_id=None, paths=dset))

elif not catalog:
file_paths = dset_to_filepaths(dset, force=True)
Expand All @@ -108,7 +110,8 @@ def consolidate(collection, **kwargs):
if len(file_paths) == 0:
raise Exception(f"No files found in given time range for {dset}")

filtered_refs[dset] = file_paths
dataset_id = None if isinstance(dset, FileMapper) else str(dset)
sources.append(DatasetSource(dataset_id=dataset_id, paths=file_paths))

else:
ds_id = derive_ds_id(dset)
Expand All @@ -125,6 +128,7 @@ def consolidate(collection, **kwargs):

logger.info(f"Found {len(result)} files")

filtered_refs = result.files()
for dataset_id, paths in result.files().items():
sources.append(DatasetSource(dataset_id=dataset_id, paths=paths))

return filtered_refs
return tuple(sources)
6 changes: 3 additions & 3 deletions src/rook/utils/ops/normalise.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def normalise(collection, apply_fixes=True):
logger.info(f"Working on datasets: {collection}")
norm_collection = ordered_dict()

for dset, file_paths in collection.items():
ds = open_dataset(dset, file_paths, apply_fixes)
norm_collection[dset] = ds
for source in collection:
ds = open_dataset(source, apply_fixes=apply_fixes)
norm_collection[source.key] = ds

return norm_collection

Expand Down
34 changes: 22 additions & 12 deletions tests/test_ops_consolidate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import rook.utils.ops.consolidate as consolidate
from rook import config
from rook.catalog.base import Result
from rook.io.datasets import DatasetSource


class DummyCollection:
Expand All @@ -17,9 +18,12 @@ def fail_get_catalog(_project):
collection = DummyCollection(["https://example.org/refs/mydataset.json"])
result = consolidate.consolidate(collection)

assert result == {
"https://example.org/refs/mydataset.json": "https://example.org/refs/mydataset.json"
}
assert result == (
DatasetSource(
dataset_id=None,
paths=("https://example.org/refs/mydataset.json",),
),
)


def test_consolidate_s3_bypasses_catalog_and_mapper(monkeypatch):
Expand All @@ -39,9 +43,12 @@ def fail_dset_to_filepaths(_dset, **_kwargs):
collection = DummyCollection(["s3://example-bucket/path/file.nc"])
result = consolidate.consolidate(collection)

assert result == {
"s3://example-bucket/path/file.nc": ["s3://example-bucket/path/file.nc"]
}
assert result == (
DatasetSource(
dataset_id=None,
paths=("s3://example-bucket/path/file.nc",),
),
)


def test_consolidate_zarr_bypasses_catalog_and_mapper(monkeypatch):
Expand All @@ -56,7 +63,7 @@ def fail_lookup(*_args, **_kwargs):
collection = DummyCollection([store])
result = consolidate.consolidate(collection)

assert result == {store: store}
assert result == (DatasetSource(dataset_id=None, paths=(store,)),)


def test_consolidate_catalog_files_can_use_s3_base_dir(monkeypatch):
Expand Down Expand Up @@ -84,8 +91,11 @@ def search(self, collection, time):
collection = DummyCollection(["c3s-cmip6.dataset"])
result = consolidate.consolidate(collection)

assert result == {
"c3s-cmip6.dataset": [
"s3://example-bucket/data/CMIP6/ScenarioMIP/Model/file_201501-210012.nc"
]
}
assert result == (
DatasetSource(
dataset_id="c3s-cmip6.dataset",
paths=(
"s3://example-bucket/data/CMIP6/ScenarioMIP/Model/file_201501-210012.nc",
),
),
)
Loading