Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 129 additions & 11 deletions OceanDataStore/catalog/oceandatacatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
Authors:
- Ollie Tooth
"""
import os
from typing import Optional

import os
import pystac
import icechunk
import numpy as np
import pystac
import xarray as xr

# -- NOC brand CSS -- #
Expand Down Expand Up @@ -170,6 +171,126 @@
</style>
"""

# -- Utility Functions -- #
def apply_bbox(ds: xr.Dataset,
bbox: tuple
) -> xr.Dataset:
"""
Apply a geographical bounding box to subset an xarray Dataset.

Parameters
----------
ds : xr.Dataset
Input xarray Dataset.
bbox : tuple
Geographical bounding box in the format (min_lon, max_lon, min_lat, max_lat).

Returns
-------
xr.Dataset
Geographically subsetted xarray Dataset.
"""
# -- Validate Inputs -- #
if not isinstance(ds, xr.Dataset):
raise ValueError("'ds' must be an xarray Dataset.")
if not (isinstance(bbox, tuple) and len(bbox) == 4):
raise ValueError("'bbox' must be a tuple of the form (min_lon, max_lon, min_lat, max_lat).")

# -- Identify geographical coordinate names & dimensions -- #
# Default lat/lon coord names:
lon_name, lat_name = "nav_lon", "nav_lat"
# Update lat/lon coord names via standard_name attributes:
for coord in ds.coords:
if ds[coord].attrs.get('standard_name', '').lower() == 'longitude':
lon_name = coord
if ds[coord].attrs.get('standard_name', '').lower() == 'latitude':
lat_name = coord

# -- Apply Bounding Box -- #
if (ds[lon_name].ndim > 1) and (ds[lat_name].ndim > 1):
# -- Case 1: 2D lat/lon coordinates -- #
# Identify lat/lon coordinate dimensions:
if ds[lon_name].dims != ds[lat_name].dims:
raise ValueError("Longitude and latitude coordinates must have the same dimensions.")
else:
y_name, x_name = ds[lon_name].dims

# Define bbox mask:
mask = (
(ds[lon_name] >= bbox[0])
& (ds[lon_name] <= bbox[2])
& (ds[lat_name] >= bbox[1])
& (ds[lat_name] <= bbox[3])
)

# Find rows/columns containing at least one valid grid point:
rows = mask.any(dim=x_name)
cols = mask.any(dim=y_name)
y_idx = np.where(rows.compute())[0]
x_idx = np.where(cols.compute())[0]

if len(y_idx) == 0 or len(x_idx) == 0:
raise ValueError("No grid points found inside bbox")

# Subset dataset to bounding box:
ds_subset = (ds
.where(mask, drop=False)
.isel({y_name: slice(y_idx.min(), y_idx.max() + 1),
x_name: slice(x_idx.min(), x_idx.max() + 1),
})
)
else:
# -- Case 2: 1D lat/lon coordinates -- #
ds_subset = ds.sel({lon_name: slice(bbox[0], bbox[1]),
lat_name: slice(bbox[2], bbox[3])
})

return ds_subset


def apply_time_bounds(ds: xr.Dataset,
start_datetime: str | None = None,
end_datetime: str | None = None
) -> xr.Dataset:
"""
Apply temporal subsetting to an xarray Dataset.

Parameters
----------
ds : xr.Dataset
Input xarray Dataset.
start_datetime : str, optional
Start datetime in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').
end_datetime : str, optional
End datetime in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').

Returns
-------
xr.Dataset
Temporally subsetted xarray Dataset.
"""
# -- Validate Inputs -- #
if not isinstance(ds, xr.Dataset):
raise ValueError("'ds' must be an xarray Dataset.")
if start_datetime is not None:
if not isinstance(start_datetime, str):
raise ValueError("'start_datetime' must be a string in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').")
if end_datetime is not None:
if not isinstance(end_datetime, str):
raise ValueError("'end_datetime' must be a string in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').")

# -- Identify time dimension -- #
for coord in ds.dims:
if 'time' in coord.lower():
time_name = coord
break

# -- Apply temporal subsetting -- #
ds_subset = ds.sel({time_name: slice(start_datetime, end_datetime)})

return ds_subset


# -- Define CatalogSummary() class -- #
class CatalogSummary:
"""
Expand Down Expand Up @@ -1007,7 +1128,7 @@ def open_dataset(self,
variable_names: Optional[list[str]] = None,
start_datetime: Optional[str] = None,
end_datetime: Optional[str] = None,
bbox: Optional[tuple[float, float, float, float]] = None,
bbox: Optional[tuple[float | int, float | int, float | int, float | int]] = None,
branch: str = "main",
consolidated: bool = True,
asset_key: Optional[str] = None
Expand All @@ -1033,7 +1154,7 @@ def open_dataset(self,
End datetime used to subset the dataset. Should be a string
in ISO format (e.g., "2024-12-31T00:00:00Z"). Default is to use
the Item end_datetime.
bbox : tuple[float, float, float, float], optional
bbox : tuple[float | int, float | int, float | int, float | int], optional
Spatial bounding box used to subset the dataset. Should be a list of four floats
representing the bounding box in the format: (min_lon, min_lat, max_lon, max_lat).
Default is to use the Item bbox.
Expand Down Expand Up @@ -1073,8 +1194,8 @@ def open_dataset(self,
raise TypeError("'end_datetime' must be a string or None.")
if not isinstance(bbox, (type(None), tuple)):
raise TypeError("'bbox' must be a tuple or None.")
if bbox is not None and (len(bbox) != 4 or not all(isinstance(coord, float) for coord in bbox)):
raise TypeError("'bbox' must be a tuple of floats in the form (lon_min, lon_max, lat_min, lat_max).")
if bbox is not None and (len(bbox) != 4 or not all(isinstance(coord, (float, int)) for coord in bbox)):
raise TypeError("'bbox' must be a tuple of the form (min_lon, min_lat, max_lon, max_lat) with float or int values.")
if not isinstance(branch, str):
raise TypeError("'branch' must be a string.")
if not isinstance(consolidated, bool):
Expand Down Expand Up @@ -1123,12 +1244,9 @@ def open_dataset(self,

# Spatio-temporal subsetting:
if bbox:
lon = ds.nav_lon.load()
lat = ds.nav_lat.load()
ds = ds.where((lon >= bbox[0]) & (lon <= bbox[2]) &
(lat >= bbox[1]) & (lat <= bbox[3]), drop=True)
ds = apply_bbox(ds=ds, bbox=bbox)

if start_datetime or end_datetime:
ds = ds.sel(time_counter=slice(start_datetime, end_datetime))
ds = apply_time_bounds(ds=ds, start_datetime=start_datetime, end_datetime=end_datetime)

return ds
Loading