From 69357013a8e1156cb32f3433f21c708b09c9ccdc Mon Sep 17 00:00:00 2001
From: Ollie Tooth
Date: Wed, 10 Jun 2026 15:49:37 +0100
Subject: [PATCH] Add apply_bbox() and apply_time_bounds() utility functions to generalise spatio-temporal subsetting in open_dataset() method.

---
Review notes (not part of the commit message):
  * apply_bbox() now uses the (min_lon, min_lat, max_lon, max_lat) order in
    its docstring, its error message and its 1D branch, matching the 2D
    branch and the open_dataset() caller.
  * apply_time_bounds() raises ValueError when no time dimension is found
    instead of failing on an unbound 'time_name'.
  * Hunk line counts are unchanged; indentation of context lines was
    reconstructed and should be checked against the target file.

 OceanDataStore/catalog/oceandatacatalog.py | 140 +++++++++++++++++++--
 1 file changed, 129 insertions(+), 11 deletions(-)

diff --git a/OceanDataStore/catalog/oceandatacatalog.py b/OceanDataStore/catalog/oceandatacatalog.py
index 659b412d..1203d24c 100644
--- a/OceanDataStore/catalog/oceandatacatalog.py
+++ b/OceanDataStore/catalog/oceandatacatalog.py
@@ -9,11 +9,12 @@ Authors:
     - Ollie Tooth
 """
 
+import os
 from typing import Optional
 
-import os
-import pystac
 import icechunk
+import numpy as np
+import pystac
 import xarray as xr
 
 # -- NOC brand CSS -- #
@@ -170,6 +171,126 @@
 """
 
 
+# -- Utility Functions -- #
+def apply_bbox(ds: xr.Dataset,
+               bbox: tuple
+               ) -> xr.Dataset:
+    """
+    Apply a geographical bounding box to subset an xarray Dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        Input xarray Dataset.
+    bbox : tuple
+        Geographical bounding box in the format (min_lon, min_lat, max_lon, max_lat).
+
+    Returns
+    -------
+    xr.Dataset
+        Geographically subsetted xarray Dataset.
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(ds, xr.Dataset):
+        raise ValueError("'ds' must be an xarray Dataset.")
+    if not (isinstance(bbox, tuple) and len(bbox) == 4):
+        raise ValueError("'bbox' must be a tuple of the form (min_lon, min_lat, max_lon, max_lat).")
+
+    # -- Identify geographical coordinate names & dimensions -- #
+    # Default lat/lon coord names:
+    lon_name, lat_name = "nav_lon", "nav_lat"
+    # Update lat/lon coord names via standard_name attributes:
+    for coord in ds.coords:
+        if ds[coord].attrs.get('standard_name', '').lower() == 'longitude':
+            lon_name = coord
+        if ds[coord].attrs.get('standard_name', '').lower() == 'latitude':
+            lat_name = coord
+
+    # -- Apply Bounding Box -- #
+    if (ds[lon_name].ndim > 1) and (ds[lat_name].ndim > 1):
+        # -- Case 1: 2D lat/lon coordinates -- #
+        # Identify lat/lon coordinate dimensions:
+        if ds[lon_name].dims != ds[lat_name].dims:
+            raise ValueError("Longitude and latitude coordinates must have the same dimensions.")
+        else:
+            y_name, x_name = ds[lon_name].dims
+
+        # Define bbox mask:
+        mask = (
+            (ds[lon_name] >= bbox[0])
+            & (ds[lon_name] <= bbox[2])
+            & (ds[lat_name] >= bbox[1])
+            & (ds[lat_name] <= bbox[3])
+        )
+
+        # Find rows/columns containing at least one valid grid point:
+        rows = mask.any(dim=x_name)
+        cols = mask.any(dim=y_name)
+        y_idx = np.where(rows.compute())[0]
+        x_idx = np.where(cols.compute())[0]
+
+        if len(y_idx) == 0 or len(x_idx) == 0:
+            raise ValueError("No grid points found inside bbox")
+
+        # Subset dataset to bounding box:
+        ds_subset = (ds
+                     .where(mask, drop=False)
+                     .isel({y_name: slice(y_idx.min(), y_idx.max() + 1),
+                            x_name: slice(x_idx.min(), x_idx.max() + 1),
+                            })
+                     )
+    else:
+        # -- Case 2: 1D lat/lon coordinates (same bbox order as Case 1) -- #
+        ds_subset = ds.sel({lon_name: slice(bbox[0], bbox[2]),
+                            lat_name: slice(bbox[1], bbox[3])
+                            })
+
+    return ds_subset
+
+
+def apply_time_bounds(ds: xr.Dataset,
+                      start_datetime: str | None = None,
+                      end_datetime: str | None = None
+                      ) -> xr.Dataset:
+    """
+    Apply temporal subsetting to an xarray Dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        Input xarray Dataset.
+    start_datetime : str, optional
+        Start datetime in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').
+    end_datetime : str, optional
+        End datetime in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').
+
+    Returns
+    -------
+    xr.Dataset
+        Temporally subsetted xarray Dataset.
+    """
+    # -- Validate Inputs -- #
+    if not isinstance(ds, xr.Dataset):
+        raise ValueError("'ds' must be an xarray Dataset.")
+    if start_datetime is not None:
+        if not isinstance(start_datetime, str):
+            raise ValueError("'start_datetime' must be a string in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').")
+    if end_datetime is not None:
+        if not isinstance(end_datetime, str):
+            raise ValueError("'end_datetime' must be a string in ISO format (e.g., 'YYYY-MM-DDTHH:MM:SS').")
+
+    # -- Identify time dimension (first dimension whose name contains 'time') -- #
+    time_name = next(
+        (dim for dim in ds.dims if 'time' in str(dim).lower()), None)
+    if time_name is None:
+        raise ValueError("'ds' must have a time dimension to apply time bounds.")
+
+    # -- Apply temporal subsetting -- #
+    ds_subset = ds.sel({time_name: slice(start_datetime, end_datetime)})
+
+    return ds_subset
+
+
 # -- Define CatalogSummary() class -- #
 class CatalogSummary:
     """
@@ -1007,7 +1128,7 @@ def open_dataset(self,
                      variable_names: Optional[list[str]] = None,
                      start_datetime: Optional[str] = None,
                      end_datetime: Optional[str] = None,
-                     bbox: Optional[tuple[float, float, float, float]] = None,
+                     bbox: Optional[tuple[float | int, float | int, float | int, float | int]] = None,
                      branch: str = "main",
                      consolidated: bool = True,
                      asset_key: Optional[str] = None
@@ -1033,7 +1154,7 @@ def open_dataset(self,
             End datetime used to subset the dataset.
             Should be a string in ISO format (e.g., "2024-12-31T00:00:00Z").
             Default is to use the Item end_datetime.
-        bbox : tuple[float, float, float, float], optional
+        bbox : tuple[float | int, float | int, float | int, float | int], optional
             Spatial bounding box used to subset the dataset.
             Should be a list of four floats representing the bounding box in the format:
             (min_lon, min_lat, max_lon, max_lat). Default is to use the Item bbox.
@@ -1073,8 +1194,8 @@ def open_dataset(self,
             raise TypeError("'end_datetime' must be a string or None.")
         if not isinstance(bbox, (type(None), tuple)):
             raise TypeError("'bbox' must be a tuple or None.")
-        if bbox is not None and (len(bbox) != 4 or not all(isinstance(coord, float) for coord in bbox)):
-            raise TypeError("'bbox' must be a tuple of floats in the form (lon_min, lon_max, lat_min, lat_max).")
+        if bbox is not None and (len(bbox) != 4 or not all(isinstance(coord, (float, int)) for coord in bbox)):
+            raise TypeError("'bbox' must be a tuple of the form (min_lon, min_lat, max_lon, max_lat) with float or int values.")
         if not isinstance(branch, str):
             raise TypeError("'branch' must be a string.")
         if not isinstance(consolidated, bool):
@@ -1123,12 +1244,9 @@ def open_dataset(self,
 
         # Spatio-temporal subsetting:
         if bbox:
-            lon = ds.nav_lon.load()
-            lat = ds.nav_lat.load()
-            ds = ds.where((lon >= bbox[0]) & (lon <= bbox[2]) &
-                          (lat >= bbox[1]) & (lat <= bbox[3]), drop=True)
+            ds = apply_bbox(ds=ds, bbox=bbox)
         if start_datetime or end_datetime:
-            ds = ds.sel(time_counter=slice(start_datetime, end_datetime))
+            ds = apply_time_bounds(ds=ds, start_datetime=start_datetime, end_datetime=end_datetime)
 
         return ds
 