Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/rook/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Dataset input and output utilities."""

from .datasets import DatasetSource, open_dataset
from .datasets import DatasetFormat, DatasetSource, Transport, open_dataset

__all__ = ["DatasetSource", "open_dataset"]
__all__ = ["DatasetFormat", "DatasetSource", "Transport", "open_dataset"]
144 changes: 85 additions & 59 deletions src/rook/io/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from urllib.parse import urlsplit

Expand All @@ -15,6 +16,24 @@
ZARR_EXT = ".zarr"


class DatasetFormat(Enum):
    """Enumeration of the on-disk data formats that Rook can open."""

    # Member values are the lowercase format names.
    NETCDF = "netcdf"
    ZARR = "zarr"
    KERCHUNK = "kerchunk"


class Transport(Enum):
    """Enumeration of the transport protocols that matter when opening data."""

    # Member values are the lowercase transport names; ``OTHER`` is the
    # catch-all member.
    FILESYSTEM = "filesystem"
    HTTP = "http"
    S3 = "s3"
    REFERENCE = "reference"
    OTHER = "other"


@dataclass(frozen=True, init=False)
class DatasetSource:
"""A normalized set of paths and its optional catalog dataset id."""
Expand Down Expand Up @@ -51,15 +70,63 @@ def key(self):
return self.dataset_id or self.paths[0]


def detect_format(source: DatasetSource) -> DatasetFormat:
    """Classify a source by data format, regardless of how it is transported.

    Only the first path of ``source`` is inspected. The Zarr check runs
    before the Kerchunk check, and NetCDF is the fallback when neither
    matches.
    """
    first_path = source.paths[0]
    if is_zarr_store(first_path):
        detected = DatasetFormat.ZARR
    elif is_kerchunk_file(first_path):
        detected = DatasetFormat.KERCHUNK
    else:
        detected = DatasetFormat.NETCDF
    return detected


def detect_transport(source: DatasetSource) -> Transport:
    """Return the single transport used by every path of ``source``.

    Raises:
        ValueError: If the per-path transports do not collapse to exactly
            one distinct value.
    """
    found = set(map(_detect_path_transport, source.paths))
    if len(found) == 1:
        (transport,) = found
        return transport

    # Sort the names so the error message is deterministic.
    names = ", ".join(sorted(item.value for item in found))
    raise ValueError(f"Dataset paths use mixed transports: {names}.")


def get_storage_options(source: DatasetSource) -> dict:
    """Look up the transport options that apply to ``source``.

    S3 sources receive the options from the central configuration; every
    other transport gets an empty dict.
    """
    if detect_transport(source) is not Transport.S3:
        return {}
    return config.get_s3_storage_options()


def open_netcdf(source: DatasetSource, storage_options: dict):
    """Open the NetCDF file(s) of ``source`` with the clisops opener.

    All paths are passed as a list. Non-empty ``storage_options`` are
    forwarded nested under ``backend_kwargs``; otherwise no extra keyword
    arguments are passed.
    """
    file_paths = list(source.paths)
    if not storage_options:
        return open_xr_dataset(file_paths)
    return open_xr_dataset(
        file_paths,
        backend_kwargs={"storage_options": storage_options},
    )


def open_zarr(source: DatasetSource, storage_options: dict):
    """Open the Zarr store found at the first path of ``source``.

    Only ``source.paths[0]`` is opened. ``storage_options`` is forwarded
    to ``xr.open_zarr`` only when it is non-empty.
    """
    store = source.paths[0]
    if storage_options:
        return xr.open_zarr(store, storage_options=storage_options)
    return xr.open_zarr(store)


def open_kerchunk(source: DatasetSource, storage_options: dict):
    """Open the Kerchunk reference found at the first path of ``source``.

    Only ``source.paths[0]`` is opened, via the clisops opener. Non-empty
    ``storage_options`` are forwarded as ``target_options``.
    """
    reference = source.paths[0]
    if storage_options:
        return open_xr_dataset(reference, target_options=storage_options)
    return open_xr_dataset(reference)


# Dispatch table: each dataset format maps to the function that opens it.
# Every opener takes ``(source, storage_options)``.
_OPENERS = {
    DatasetFormat.NETCDF: open_netcdf,
    DatasetFormat.ZARR: open_zarr,
    DatasetFormat.KERCHUNK: open_kerchunk,
}


def open_dataset(source: DatasetSource, *, apply_fixes=True):
"""Open an xarray Dataset and optionally apply rook-native fixes."""
zarr_store = get_zarr_store(source)
if zarr_store:
ds = xr.open_zarr(zarr_store, **get_zarr_open_kwargs(zarr_store))
else:
open_kwargs = get_s3_open_kwargs(source)
paths = source.paths[0] if is_kerchunk_file(source.paths[0]) else list(source.paths)
ds = open_xr_dataset(paths, **open_kwargs)
opener = _OPENERS[detect_format(source)]
ds = opener(source, get_storage_options(source))

if apply_fixes and source.dataset_id:
ds = apply_dataset_fixes(source.dataset_id, ds)
Expand Down Expand Up @@ -89,21 +156,6 @@ def is_kerchunk_file(dset):
return path.endswith(KERCHUNK_EXTS)


def is_s3_uri(dset):
    """Tell whether ``dset`` is an ``s3://`` object URI.

    ``Path`` inputs are converted to strings first. Anything that is not
    a string after that conversion, and any blank string, yields False.
    The prefix check ignores surrounding whitespace and letter case.
    """
    candidate = str(dset) if isinstance(dset, Path) else dset
    if not isinstance(candidate, str):
        return False

    stripped = candidate.strip()
    return bool(stripped) and stripped.lower().startswith("s3://")


def is_zarr_store(dset):
"""Return True when the input looks like a Zarr store path."""
if isinstance(dset, Path):
Expand All @@ -120,40 +172,14 @@ def is_zarr_store(dset):
return path.endswith(ZARR_EXT)


def get_zarr_store(source: DatasetSource):
    """Return the lone Zarr store path of ``source``, or None.

    A path is returned only when the source holds exactly one path and
    that path looks like a Zarr store.
    """
    paths = source.paths
    if len(paths) != 1:
        return None

    only_path = paths[0]
    return only_path if is_zarr_store(only_path) else None


def get_zarr_open_kwargs(store):
    """Build the keyword arguments used to open a Zarr store with xarray.

    The result is empty unless ``store`` is an S3 URI and the shared S3
    storage options are non-empty.
    """
    if is_s3_uri(store):
        options = get_s3_storage_options()
        if options:
            return {"storage_options": options}
    return {}


def get_s3_open_kwargs(source: DatasetSource):
    """Build opener keyword arguments for NetCDF inputs hosted on S3.

    Only the first path is inspected. The result is empty when that path
    is not an S3 URI, is a Kerchunk file, or is a Zarr store, and also
    when no S3 storage options are configured.
    """
    first_path = source.paths[0]

    plain_s3_input = (
        is_s3_uri(first_path)
        and not is_kerchunk_file(first_path)
        and not is_zarr_store(first_path)
    )
    if not plain_s3_input:
        return {}

    options = get_s3_storage_options()
    if options:
        return {"backend_kwargs": {"storage_options": options}}
    return {}


def get_s3_storage_options():
    """Fetch the shared S3 transport options from the central configuration."""
    options = config.get_s3_storage_options()
    return options
def _detect_path_transport(path: str) -> Transport:
    """Classify a single path by the transport implied by its URL scheme.

    Surrounding whitespace is ignored and the scheme is compared
    case-insensitively.

    Args:
        path: A filesystem path or URL string.

    Returns:
        ``Transport.FILESYSTEM`` for paths without a scheme, ``file://``
        URLs and Windows drive-letter paths; ``Transport.HTTP`` for
        ``http``/``https``; ``Transport.S3`` for ``s3``;
        ``Transport.REFERENCE`` for ``reference``; ``Transport.OTHER``
        for any other scheme.
    """
    scheme = urlsplit(path.strip()).scheme.lower()
    # ``urlsplit`` reports the drive letter of a Windows path such as
    # "C:/data/file.nc" as a one-character scheme. Treat it as a local
    # path instead of letting it fall through to Transport.OTHER.
    if len(scheme) <= 1 or scheme == "file":
        return Transport.FILESYSTEM
    if scheme in {"http", "https"}:
        return Transport.HTTP
    if scheme == "s3":
        return Transport.S3
    if scheme == "reference":
        return Transport.REFERENCE
    return Transport.OTHER
22 changes: 15 additions & 7 deletions src/rook/utils/ops/consolidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,25 @@

from rook.catalog import get_catalog
from rook.io.datasets import (
DatasetFormat,
DatasetSource,
is_kerchunk_file,
is_s3_uri,
is_zarr_store,
Transport,
detect_format,
detect_transport,
)

from .helpers import wrap_sequence


def _bypasses_catalog(value):
    """Decide whether a directly supplied source skips project resolution.

    ``value`` is wrapped in a ``DatasetSource`` without a dataset id. The
    result is True when the detected format is anything other than
    NetCDF, or when the detected transport is S3.
    """
    source = DatasetSource(dataset_id=None, paths=value)
    if detect_format(source) is not DatasetFormat.NETCDF:
        return True
    return detect_transport(source) is Transport.S3


def to_year(time_string):
    """Extract the year from a time string and return it as an integer.

    The year is the text before the first ``-`` in ``time_string``.
    """
    year_text, _, _ = time_string.partition("-")
    return int(year_text)
Expand Down Expand Up @@ -86,9 +96,7 @@ def consolidate(collection, **kwargs):

if (
not isinstance(collection[0], FileMapper)
and not is_kerchunk_file(collection[0])
and not is_s3_uri(collection[0])
and not is_zarr_store(collection[0])
and not _bypasses_catalog(collection[0])
):
project = get_project_name(collection[0])
catalog = get_catalog(project)
Expand All @@ -98,7 +106,7 @@ def consolidate(collection, **kwargs):
time_param = kwargs.get("time")

for dset in collection:
if is_kerchunk_file(dset) or is_zarr_store(dset) or is_s3_uri(dset):
if not isinstance(dset, FileMapper) and _bypasses_catalog(dset):
sources.append(DatasetSource(dataset_id=None, paths=dset))

elif not catalog:
Expand Down
Loading