From 3da46a0e10c1bad8fe65ed1f1110063683f21ed6 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Tue, 9 Sep 2025 11:51:34 -0400 Subject: [PATCH] Faster subset gridpoint --- HISTORY.rst | 1 + clisops/core/subset.py | 80 +++++++++++++++++----------------- clisops/utils/dataset_utils.py | 6 +++ tests/test_core_subset.py | 21 ++------- 4 files changed, 49 insertions(+), 59 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index c92faa9b..5ee0ec8b 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -8,6 +8,7 @@ New Features ^^^^^^^^^^^^ * Added an `engine` argument to `Grid.ds.to_netcdf()` to allow users to specify the engine used for writing NetCDF files (#439). * Coding conventions have been updated to use Python 3.10+ features (#439). +* `core.subset.subset_gridpoint` will find nearest neighbours using a KDTree based on euclidean distance in lat/lon space instead of using great circle distances. The small loss in precision is compensated by a significant performance boost, especially for large grids and long point lists (#452). Bug Fixes ^^^^^^^^^ diff --git a/clisops/core/subset.py b/clisops/core/subset.py index c6fc9ac4..aa792267 100644 --- a/clisops/core/subset.py +++ b/clisops/core/subset.py @@ -16,6 +16,7 @@ from pyproj import Geod from pyproj.crs import CRS from pyproj.exceptions import CRSError +from scipy.spatial import KDTree from shapely import vectorized from shapely.geometry import LineString, MultiPolygon, Point, Polygon from shapely.ops import split, unary_union @@ -1526,8 +1527,7 @@ def subset_gridpoint( Extract one or more of the nearest gridpoint(s) from datarray based on lat lon coordinate(s). Return a subsetted data array (or Dataset) for the grid point(s) falling nearest the input longitude and latitude - coordinates. Optionally, subset the data array for years falling within provided date bounds. - Time series can optionally be subsetted by dates. + coordinates (as computed with a lat/lon euclidean distance). 
Time series can optionally be subsetted by dates. If 1D sequences of coordinates are given, the gridpoints will be concatenated along the new dimension "site". Parameters @@ -1576,55 +1576,51 @@ def subset_gridpoint( # Subset lat lon point prSub = subset_gridpoint(ds.pr, lon=-75, lat=45) - # Subset multiple variables in a single dataset - ds = xr.open_mfdataset([path_to_tasmax_file, path_to_tasmin_file]) - dsSub = subset_gridpoint(ds, lon=-75, lat=45) + # Drop locations where the closest gridpoint was too far (here 1000 km) + prSub = subset_gridpoint(ds.pr, lon=[-75, -60], lat=[45, 40], tolerance=1e6) """ if lat is None or lon is None: raise ValueError("Insufficient coordinates provided to locate grid point(s).") ptdim = lat.dims[0] + dist = None - lon_name = lon.name or "lon" - lat_name = lat.name or "lat" - + srclon = get_coord_by_type(da, "longitude", ignore_aux_coords=False) + srclat = get_coord_by_type(da, "latitude", ignore_aux_coords=False) # make sure input data has 'lon' and 'lat'(dims, coordinates, or data_vars) - if hasattr(da, lon_name) and hasattr(da, lat_name): - dims = list(da.dims) - - # if 'lon' and 'lat' are present as data dimensions use the .sel method. - if lat_name in dims and lon_name in dims: - da = da.sel(lat=lat, lon=lon, method="nearest") - - if tolerance is not None or add_distance: - # Calculate the geodesic distance between grid points and the point of interest. 
- dist = distance(da, lon=lon, lat=lat) - else: - dist = None - + if srclon is not None and srclat is not None: + srclon = da[srclon] + srclat = da[srclat] + if srclon.ndim == 1 and srclat.ndim == 1 and srclon.dims != srclat.dims: + # lon and lat are 1D and don't share dims : rectilinear grid + da = da.sel({srclat.dims[0]: lat, srclon.dims[0]: lon}, method="nearest") + elif srclon.ndim == 2 and srclat.dims == srclon.dims: + # lon and lat are 2D and share dims : curvilinear grid + pts = np.vstack([srclon.values.flatten(), srclat.values.flatten()]).T + # Input is a well-behaved grid: skip the compact/balanced options, which mostly add build time + tree = KDTree(pts, compact_nodes=False, balanced_tree=False) + _, idxs = tree.query(np.vstack([lon.values, lat.values]).T) + iY, iX = np.unravel_index(idxs, shape=srclon.shape) + iY = lon.copy(data=iY) + iX = lon.copy(data=iX) + da = da.isel({srclon.dims[0]: iY, srclon.dims[1]: iX}) + elif srclon.ndim == 1 and srclat.ndim == 1 and srclon.dims == srclat.dims: + # lon and lat are 1D and share dims : list of points case + pts = np.vstack([srclon.values, srclat.values]).T + tree = KDTree(pts) + _, idxs = tree.query(np.vstack([lon.values, lat.values]).T) + idxs = lon.copy(data=idxs) + da = da.isel({srclon.dims[0]: idxs}) else: - # Calculate the geodesic distance between grid points and the point of interest. 
- dist = distance(da, lon=lon, lat=lat) - pts = [] - dists = [] - for site in dist[ptdim]: - # Find the indices for the closest point - inds = np.unravel_index(dist.sel({ptdim: site}).argmin(), dist.sel({ptdim: site}).shape) - - # Select data from closest point - args = {xydim: ind for xydim, ind in zip(dist.dims, inds, strict=False)} - pts.append(da.isel(**args)) - dists.append(dist.isel(**args)) - da = xarray.concat(pts, dim=ptdim) - dist = xarray.concat(dists, dim=ptdim) + raise ValueError(f"Unrecognized coordinate type for longitude and latitude ({srclon.name}, {srclat.name})") else: - raise ( - Exception( - f'{subset_gridpoint.__name__} requires input data with "lon" and "lat" coordinates or data variables.' - ) - ) + raise ValueError("subset_gridpoint requires input data with longitude and latitude coordinates.") + + if tolerance is not None or add_distance: + # Calculate the geodesic distance between grid points and the point of interest. + dist = distance(da, lon=lon, lat=lat) - if tolerance is not None and dist is not None: + if tolerance is not None: da = da.where(dist < tolerance) if add_distance: @@ -1632,6 +1628,8 @@ def subset_gridpoint( if len(lat) == 1: da = da.squeeze(ptdim) + else: + da = da.transpose(..., ptdim) if start_date or end_date: da = subset_time(da, start_date=start_date, end_date=end_date) diff --git a/clisops/utils/dataset_utils.py b/clisops/utils/dataset_utils.py index 620a0a30..98035b87 100644 --- a/clisops/utils/dataset_utils.py +++ b/clisops/utils/dataset_utils.py @@ -174,6 +174,9 @@ def is_latitude(coord: xr.DataArray | xr.Dataset) -> bool: if hasattr(coord, "long_name") and coord.long_name == "latitude": return True + if getattr(coord, "name", None) == "lat": + return True + return False @@ -203,6 +206,9 @@ def is_longitude(coord: xr.DataArray | xr.Dataset) -> bool: if hasattr(coord, "long_name") and coord.long_name == "longitude": return True + if getattr(coord, "name", None) == "lon": + return True + return False diff --git a/tests/test_core_subset.py 
b/tests/test_core_subset.py index a842e64a..328066fa 100644 --- a/tests/test_core_subset.py +++ b/tests/test_core_subset.py @@ -177,6 +177,7 @@ def test_dataset(self, nimbus): da = xr.open_mfdataset( [nimbus.fetch(self.nc_tasmax_file), nimbus.fetch(self.nc_tasmin_file)], combine="by_coords", + compat="override", ) lon = -72.4 lat = 46.1 @@ -222,15 +223,7 @@ def test_irregular(self, nimbus): # test_irregular transposed: da1 = xr.open_dataset(nimbus.fetch(self.nc_2dlonlat)).tasmax - dims = list(da1.dims) - dims.reverse() - daT = xr.DataArray(np.transpose(da1.values), dims=dims) - for d in daT.dims: - args = dict() - args[d] = da1[d] - daT = daT.assign_coords(**args) - daT = daT.assign_coords(lon=(["rlon", "rlat"], np.transpose(da1.lon.values))) - daT = daT.assign_coords(lat=(["rlon", "rlat"], np.transpose(da1.lat.values))) + daT = da1.transpose(*reversed(da1.dims)) out1 = subset.subset_gridpoint(daT, lon=lon, lat=lat) np.testing.assert_almost_equal(out1.lon, lon, 1) @@ -238,15 +231,7 @@ def test_irregular(self, nimbus): np.testing.assert_array_equal(out, out1) # Dataset with tasmax, lon and lat as data variables (i.e. lon, lat not coords of tasmax) - daT1 = xr.DataArray(np.transpose(da1.values), dims=dims) - for d in daT1.dims: - args = dict() - args[d] = da1[d] - daT1 = daT1.assign_coords(**args) - dsT = xr.Dataset(data_vars=None, coords=daT1.coords) - dsT["tasmax"] = daT1 - dsT["lon"] = xr.DataArray(np.transpose(da1.lon.values), dims=["rlon", "rlat"]) - dsT["lat"] = xr.DataArray(np.transpose(da1.lat.values), dims=["rlon", "rlat"]) + dsT = daT.to_dataset().reset_coords(["lon", "lat"]) out2 = subset.subset_gridpoint(dsT, lon=lon, lat=lat) np.testing.assert_almost_equal(out2.lon, lon, 1) np.testing.assert_almost_equal(out2.lat, lat, 1)