From a41c8e5cf24a4b22453e565ff2d6f7b3f494093e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 5 Jun 2025 18:10:41 +0200 Subject: [PATCH 01/24] Add helpers to check Field data on init --- parcels/xgrid.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 3948d82871..c160d49d84 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -1,12 +1,17 @@ -from typing import Literal +from collections.abc import Hashable +from typing import Literal, cast import numpy as np import numpy.typing as npt +import xarray as xr from parcels import xgcm from parcels.basegrid import BaseGrid from parcels.tools.converters import TimeConverter +_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] +_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] + def get_dimensionality(axis: xgcm.Axis | None) -> int: if axis is None: @@ -165,3 +170,49 @@ def _gtype(self): return GridType.CurvilinearSGrid def search(self, z, y, x, ei=None, search2D=False): ... + + +def get_direction_axis(grid: xgcm.Grid, var: str) -> _AXIS_DIRECTION | None: + """For a given variable name in a grid, returns the direction axis it is on.""" + for direction, axis in grid.axes.items(): + if var in axis.coords.values(): + return direction + return None + + +def get_position(grid: xgcm.Grid, var: str) -> _AXIS_POSITION | None: + """For a given variable, returns the position of the variable in the grid.""" + for axis in grid.axes.values(): + var_to_position = {var: position for position, var in axis.coords.items()} + + if var in var_to_position: + return var_to_position[var] + return None + + +def assert_valid_field_array_ordering(da: xr.DataArray, grid: xgcm.Grid): + # ? This works well for one file, but what happens if the Field and Grid are stored in different files? + dim_direction = {dim: get_direction_axis(grid, dim) for dim in da.dims} + + if None in dim_direction.values(): + for dim, direction in dim_direction.items(): + if direction is None: + raise ValueError( + f"Dimension {dim!r} for DataArray {da.name!r} with dims {da.dims} is not associated with a direction on the provided grid." + ) + + dim_direction = cast(dict[Hashable, _AXIS_DIRECTION], dim_direction) + + # Assert all dimensions are present + if set(dim_direction.values()) != {"T", "Z", "Y", "X"}: + raise ValueError( + f"DataArray {da.name!r} with dims {da.dims} has directions {tuple(dim_direction.values())}." + "Expected directions of 'T', 'Z', 'Y', and 'X'." + ) + + # Assert order is t, z, y, x + # ? Is this even necessary? Can't we just fetch the direction from the grid? + if list(dim_direction.values()) != ["T", "Z", "Y", "X"]: + raise ValueError( + f"Dimension order for array {da.name!r} is not valid. Got {tuple(dim_direction.keys())} with associated directions of {tuple(dim_direction.values())}. Expected directions of ('T', 'Z', 'Y', 'X'). Transpose your array accordingly." + ) From c427da22b69b1664dc93a69f2b31674056e8dc02 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 11:11:22 +0200 Subject: [PATCH 02/24] Update helper functions --- parcels/xgrid.py | 73 +++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index c160d49d84..611ad291c5 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -1,4 +1,4 @@ -from collections.abc import Hashable +from collections.abc import Hashable, Mapping from typing import Literal, cast import numpy as np @@ -10,7 +10,9 @@ from parcels.tools.converters import TimeConverter _AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] +_AXIS_DIRECTION_SPATIAL = Literal["X", "Y", "Z"] _AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] +_XGCM_AXES = Mapping[_AXIS_DIRECTION, xgcm.Axis] def get_dimensionality(axis: xgcm.Axis | None) -> int: @@ -172,47 +174,66 @@ def _gtype(self): def search(self, z, y, x, ei=None, search2D=False): ... -def get_direction_axis(grid: xgcm.Grid, var: str) -> _AXIS_DIRECTION | None: - """For a given variable name in a grid, returns the direction axis it is on.""" - for direction, axis in grid.axes.items(): - if var in axis.coords.values(): - return direction +def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_DIRECTION | None: + """For a given dimension name in a grid, returns the direction axis it is on.""" + for axis_name, axis in axes.items(): + if dim in axis.coords.values(): + return axis_name return None -def get_position(grid: xgcm.Grid, var: str) -> _AXIS_POSITION | None: - """For a given variable, returns the position of the variable in the grid.""" - for axis in grid.axes.values(): +def get_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_POSITION | None: + """For a given dimension, returns the position of the variable in the grid.""" + for axis in axes.values(): var_to_position = {var: position for position, var in axis.coords.items()} - if var in var_to_position: - return var_to_position[var] + if dim in var_to_position: + return var_to_position[dim] return None -def assert_valid_field_array_ordering(da: xr.DataArray, grid: xgcm.Grid): - # ? This works well for one file, but what happens if the Field and Grid are stored in different files? - dim_direction = {dim: get_direction_axis(grid, dim) for dim in da.dims} +def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): + """ + Asserts that for a data array: + - All dimensions are associated with a direction on the grid + - These directions are T, Z, Y, X and the array is ordered as T, Z, Y, X + """ + dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} - if None in dim_direction.values(): - for dim, direction in dim_direction.items(): - if direction is None: - raise ValueError( - f"Dimension {dim!r} for DataArray {da.name!r} with dims {da.dims} is not associated with a direction on the provided grid." - ) + for dim, direction in dim_to_axis.items(): + if direction is None: + raise ValueError( + f"Dimension {dim!r} for DataArray {da.name!r} with dims {da.dims} is not associated with a direction on the provided grid." + ) - dim_direction = cast(dict[Hashable, _AXIS_DIRECTION], dim_direction) + dim_to_axis = cast(dict[Hashable, _AXIS_DIRECTION], dim_to_axis) # Assert all dimensions are present - if set(dim_direction.values()) != {"T", "Z", "Y", "X"}: + if set(dim_to_axis.values()) != {"T", "Z", "Y", "X"}: raise ValueError( - f"DataArray {da.name!r} with dims {da.dims} has directions {tuple(dim_direction.values())}." + f"DataArray {da.name!r} with dims {da.dims} has directions {tuple(dim_to_axis.values())}." "Expected directions of 'T', 'Z', 'Y', and 'X'." ) # Assert order is t, z, y, x - # ? Is this even necessary? Can't we just fetch the direction from the grid? - if list(dim_direction.values()) != ["T", "Z", "Y", "X"]: + if list(dim_to_axis.values()) != ["T", "Z", "Y", "X"]: raise ValueError( - f"Dimension order for array {da.name!r} is not valid. Got {tuple(dim_direction.keys())} with associated directions of {tuple(dim_direction.values())}. Expected directions of ('T', 'Z', 'Y', 'X'). Transpose your array accordingly." + f"Dimension order for array {da.name!r} is not valid. Got {tuple(dim_to_axis.keys())} with associated directions of {tuple(dim_to_axis.values())}. Expected directions of ('T', 'Z', 'Y', 'X'). Transpose your array accordingly." ) + + +def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): + """ + Asserts that the provided longitude and latitude DataArrays are defined appropriately + on the F points to match the internal representation in Parcels. + + - Longitude and latitude must be 1D or 2D + - Both are defined on the left points (i.e., not the centers) + - If 1D: + - Longitude is associated with the X axis + - Latitude is associated with the Y axis + - If 2D: + - Lon and lat are defined on the same dimensions + - Lon and lat are transposed such they're Y, X + """ + ... From 8d6cc6cd68fe8666d55d40b05ec44a9ebaa1273b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 11:27:42 +0200 Subject: [PATCH 03/24] Add assert_valid_lon_lat to xgrid.py and update helpers --- parcels/xgrid.py | 63 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 611ad291c5..d17e94fcbc 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -192,12 +192,7 @@ def get_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_POSITION | N return None -def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): - """ - Asserts that for a data array: - - All dimensions are associated with a direction on the grid - - These directions are T, Z, Y, X and the array is ordered as T, Z, Y, X - """ +def assert_all_dimensions_correspond_with_axis(da: xr.DataArray, axes: _XGCM_AXES) -> None: dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} for dim, direction in dim_to_axis.items(): @@ -206,6 +201,16 @@ def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): f"Dimension {dim!r} for DataArray {da.name!r} with dims {da.dims} is not associated with a direction on the provided grid." ) + +def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): + """ + Asserts that for a data array: + - All dimensions are associated with a direction on the grid + - These directions are T, Z, Y, X and the array is ordered as T, Z, Y, X + """ + assert_all_dimensions_correspond_with_axis(da, axes) + + dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} dim_to_axis = cast(dict[Hashable, _AXIS_DIRECTION], dim_to_axis) # Assert all dimensions are present @@ -227,7 +232,7 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): Asserts that the provided longitude and latitude DataArrays are defined appropriately on the F points to match the internal representation in Parcels. - - Longitude and latitude must be 1D or 2D + - Longitude and latitude must be 1D or 2D (both must have the same dimensionality) - Both are defined on the left points (i.e., not the centers) - If 1D: - Longitude is associated with the X axis @@ -236,4 +241,46 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): - Lon and lat are defined on the same dimensions - Lon and lat are transposed such they're Y, X """ - ... + assert_all_dimensions_correspond_with_axis(da_lon, axes) + assert_all_dimensions_correspond_with_axis(da_lat, axes) + + dim_to_position = {dim: get_position_from_dim_name(axes, dim) for dim in da_lon.dims} + dim_to_position.update({dim: get_position_from_dim_name(axes, dim) for dim in da_lat.dims}) + + for dim in da_lon.dims: + if get_position_from_dim_name(axes, dim) == "center": + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} is defined on the center of the grid, but must be defined on the F points." + ) + for dim in da_lat.dims: + if get_position_from_dim_name(axes, dim) == "center": + raise ValueError( + f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is defined on the center of the grid, but must be defined on the F points." + ) + + if da_lon.ndim != da_lat.ndim: + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} have different dimensionalities." + ) + if da_lon.ndim not in (1, 2): + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be 1D or 2D." + ) + + if da_lon.ndim == 1: + if get_axis_from_dim_name(da_lon.dims[0]) != "X": + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} is not associated with the X axis." + ) + if get_axis_from_dim_name(da_lat.dims[0]) != "Y": + raise ValueError( + f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is not associated with the Y axis." + ) + + if da_lon.ndim == 2: + lon_axes = [get_axis_from_dim_name(dim) for dim in da_lon.dims] + lat_axes = [get_axis_from_dim_name(dim) for dim in da_lat.dims] + if lon_axes != lat_axes != ["Y", "X"]: + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the X and Y axes and transposed to have dimensions in order of Y, X." + ) From db0549dcb24cf1d1c7994994237bf314703c6bb3 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 11:42:49 +0200 Subject: [PATCH 04/24] Add testing stubs --- tests/v4/test_xgrid.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index a1c0d9472d..5638432acf 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -8,7 +8,9 @@ from parcels._datasets.structured.generic import T, X, Y, Z, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter -from parcels.xgrid import XGrid +from parcels.xgrid import ( + XGrid, +) GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) @@ -74,3 +76,13 @@ def test_xgrid_against_old(ds, attr): actual = getattr(grid, attr) expected = getattr(old_grid, attr) assert_equal(actual, expected) + + +def test_invalid_xgrid_field_array(ds): + """Stress test initialiser by creating incompatible datasets that test the edge cases""" + ... + + +def test_invalid_lon_lat(ds): + """Stress test the grid initialiser by creating incompatible datasets that test the edge cases""" + ... From f3edf12c12c75df2144f4666695dc4fd86c75703 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:36:17 +0200 Subject: [PATCH 05/24] Bugfix assert_valid_lon_lat and add to xgrid init --- parcels/xgrid.py | 10 ++++++---- tests/v4/test_xgrid.py | 7 ++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index d17e94fcbc..4faef9cd44 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -47,6 +47,8 @@ class XGrid(BaseGrid): def __init__(self, grid: xgcm.Grid, mesh="flat"): self.xgcm_grid = grid self.mesh = mesh + ds = grid._ds + assert_valid_lon_lat(ds["lon"], ds["lat"], grid.axes) # ! Not ideal... Triggers computation on a throwaway item. Keeping for now for v3 compat, will be removed in v4. self.lonlat_minmax = np.array( @@ -268,18 +270,18 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): ) if da_lon.ndim == 1: - if get_axis_from_dim_name(da_lon.dims[0]) != "X": + if get_axis_from_dim_name(axes, da_lon.dims[0]) != "X": raise ValueError( f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} is not associated with the X axis." ) - if get_axis_from_dim_name(da_lat.dims[0]) != "Y": + if get_axis_from_dim_name(axes, da_lat.dims[0]) != "Y": raise ValueError( f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is not associated with the Y axis." ) if da_lon.ndim == 2: - lon_axes = [get_axis_from_dim_name(dim) for dim in da_lon.dims] - lat_axes = [get_axis_from_dim_name(dim) for dim in da_lat.dims] + lon_axes = [get_axis_from_dim_name(axes, dim) for dim in da_lon.dims] + lat_axes = [get_axis_from_dim_name(axes, dim) for dim in da_lat.dims] if lon_axes != lat_axes != ["Y", "X"]: raise ValueError( f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the X and Y axes and transposed to have dimensions in order of Y, X." diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 5638432acf..4c7794bdcb 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -61,7 +61,7 @@ def test_xgrid_properties_ground_truth(ds, attr, expected): "_gtype", ], ) -@pytest.mark.parametrize("ds", datasets.values()) +@pytest.mark.parametrize("ds", [pytest.param(ds, id=key) for key, ds in datasets.items()]) def test_xgrid_against_old(ds, attr): grid = XGrid(xgcm.Grid(ds, periodic=False)) @@ -78,6 +78,11 @@ def test_xgrid_against_old(ds, attr): assert_equal(actual, expected) +@pytest.mark.parametrize("ds", [pytest.param(ds, id=key) for key, ds in datasets.items()]) +def test_grid_init_on_generic_datasets(ds): + XGrid(xgcm.Grid(ds, periodic=False)) + + def test_invalid_xgrid_field_array(ds): """Stress test initialiser by creating incompatible datasets that test the edge cases""" ... From 106236e4940cd7728ec6c3202dc533d39902428d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:51:00 +0200 Subject: [PATCH 06/24] Add test_invalid_xgrid_field_array --- parcels/xgrid.py | 8 ++++++-- tests/v4/test_xgrid.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 4faef9cd44..72efbb6ccf 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -280,9 +280,13 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): ) if da_lon.ndim == 2: + if da_lon.dims != da_lat.dims: + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the same dimensions." + ) + lon_axes = [get_axis_from_dim_name(axes, dim) for dim in da_lon.dims] - lat_axes = [get_axis_from_dim_name(axes, dim) for dim in da_lat.dims] - if lon_axes != lat_axes != ["Y", "X"]: + if lon_axes != ["Y", "X"]: raise ValueError( f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the X and Y axes and transposed to have dimensions in order of Y, X." ) diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 4c7794bdcb..6d30e78dc7 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import xarray as xr from numpy.testing import assert_allclose from parcels import xgcm @@ -83,9 +84,34 @@ def test_grid_init_on_generic_datasets(ds): XGrid(xgcm.Grid(ds, periodic=False)) -def test_invalid_xgrid_field_array(ds): +def test_invalid_xgrid_field_array(): """Stress test initialiser by creating incompatible datasets that test the edge cases""" - ... + ds = datasets["ds_2d_left"].copy() + ds["lon"], ds["lat"] = xr.broadcast(ds["YC"], ds["XC"]) + + with pytest.raises( + ValueError, + match=".*is defined on the center of the grid, but must be defined on the F points\.", + ): + XGrid(xgcm.Grid(ds, periodic=False)) + + ds = datasets["ds_2d_left"].copy() + ds["lon"], _ = xr.broadcast(ds["YG"], ds["XG"]) + with pytest.raises( + ValueError, + match=".*have different dimensionalities\.", + ): + XGrid(xgcm.Grid(ds, periodic=False)) + + ds = datasets["ds_2d_left"].copy() + ds["lon"], ds["lat"] = xr.broadcast(ds["YG"], ds["XG"]) + ds["lon"], ds["lat"] = ds["lon"].transpose(), ds["lat"].transpose() + + with pytest.raises( + ValueError, + match=".*must be defined on the X and Y axes and transposed to have dimensions in order of Y, X\.", + ): + XGrid(xgcm.Grid(ds, periodic=False)) def test_invalid_lon_lat(ds): From feb4a7c5e6e0e7b3c2e1d0577bbf8c1bd49cbe23 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:20:37 +0200 Subject: [PATCH 07/24] Add _iterate_over_cells helper for 2D lon and lat arrays --- parcels/xgrid.py | 37 ++++++++++++++++++++++++++++++++++ tests/v4/test_xgrid.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 72efbb6ccf..3354f2693b 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -290,3 +290,40 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): raise ValueError( f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the X and Y axes and transposed to have dimensions in order of Y, X." ) + + +def _iterate_over_cells( + *, + lat: np.ndarray, + lon: np.ndarray, +): + """ + Takes in 2 2D arrays representing the F points of a grid and returns a generator that yields + the individual cells of the grid as 2D arrays of the corners of the cells. + + Parameters + ---------- + lon : np.ndarray + 2D array of the longitude F points of the grid of shape (Y, X) + lat : np.ndarray + 2D array of the latitude F points of the grid of shape (Y, X). + + Yields + ------ + np.ndarray + 2D array of shape (4, 2) representing the corners of the cell in the order: + bottom left, bottom right, top right, top left. Output is provided in lat, lon order. + """ + assert lon.ndim == 2 and lat.ndim == 2, "lon and lat must be 2D arrays." + assert lon.shape == lat.shape, "lon and lat must have the same shape." + + for y in range(lon.shape[0] - 1): + for x in range(lon.shape[1] - 1): + yield np.array( + [ + [lat[y, x], lon[y, x]], # bottom left + [lat[y, x + 1], lon[y, x + 1]], # bottom right + [lat[y + 1, x + 1], lon[y + 1, x + 1]], # top right + [lat[y + 1, x], lon[y + 1, x]], # top left + ] + ) diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 6d30e78dc7..54af342164 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -1,4 +1,5 @@ from collections import namedtuple +from typing import Literal import numpy as np import pytest @@ -11,6 +12,7 @@ from parcels.tools.converters import TimeConverter from parcels.xgrid import ( XGrid, + _iterate_over_cells, ) GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) @@ -117,3 +119,46 @@ def test_invalid_xgrid_field_array(): def test_invalid_lon_lat(ds): """Stress test the grid initialiser by creating incompatible datasets that test the edge cases""" ... + + +def test_iterate_over_cells(): + ny = 3 # Number of cells in the y-direction + nx = 6 # Number of cells in the x-direction + lon = np.arange(ny + 1) + lat = np.arange(nx + 1) + LAT, LON = np.meshgrid(lat, lon, indexing="ij") + + # Call the function and collect the output + cells = list(_iterate_over_cells(lat=LAT, lon=LON)) + assert len(cells) == ny * nx, "Number of cells does not match expected." + + for cell in cells: + _assert_point_is("east", 1, cell[0], cell[1]) + _assert_point_is("north", 1, cell[1], cell[2]) + _assert_point_is("west", 1, cell[2], cell[3]) + + +def test__assert_point_is(): + _assert_point_is("east", 1, np.array([0, 0]), np.array([0, 1])) + _assert_point_is("west", 1, np.array([0, 1]), np.array([0, 0])) + _assert_point_is("north", 1, np.array([0, 0]), np.array([1, 0])) + _assert_point_is("south", 1, np.array([1, 0]), np.array([0, 0])) + + +def _assert_point_is( + direction: Literal["east", "west", "north", "south"], by: int, reference_cell: np.ndarray, test_cell: np.ndarray +): + """cell1 and cell2 are arrays of (lat, lon)""" + match direction: + case "east": + delta = np.array([0, by]) + case "west": + delta = np.array([0, -by]) + case "north": + delta = np.array([by, 0]) + case "south": + delta = np.array([-by, 0]) + case _: + raise ValueError(f"Invalid method: {direction}") + + np.testing.assert_allclose(reference_cell + delta, test_cell) From 36e4aa79d56c728da0b0121a27ee2ef87385e8ee Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:32:10 +0200 Subject: [PATCH 08/24] Fix test name --- tests/v4/test_xgrid.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 54af342164..3d7aa24059 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -88,6 +88,11 @@ def test_grid_init_on_generic_datasets(ds): def test_invalid_xgrid_field_array(): """Stress test initialiser by creating incompatible datasets that test the edge cases""" + ... + + +def test_invalid_lon_lat(ds): + """Stress test the grid initialiser by creating incompatible datasets that test the edge cases""" ds = datasets["ds_2d_left"].copy() ds["lon"], ds["lat"] = xr.broadcast(ds["YC"], ds["XC"]) @@ -116,11 +121,6 @@ def test_invalid_xgrid_field_array(): XGrid(xgcm.Grid(ds, periodic=False)) -def test_invalid_lon_lat(ds): - """Stress test the grid initialiser by creating incompatible datasets that test the edge cases""" - ... - - def test_iterate_over_cells(): ny = 3 # Number of cells in the y-direction nx = 6 # Number of cells in the x-direction From 0861541097befaf0799521498e9d1220dcb8910b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:32:49 +0200 Subject: [PATCH 09/24] Add docstrings to XGrid xdim, ydim, zdim attributes --- parcels/xgrid.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 3354f2693b..cb93c2ec6f 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -15,7 +15,7 @@ _XGCM_AXES = Mapping[_AXIS_DIRECTION, xgcm.Axis] -def get_dimensionality(axis: xgcm.Axis | None) -> int: +def get_tracer_dimensionality(axis: xgcm.Axis | None) -> int: if axis is None: return 1 first_coord = list(axis.coords.items())[0] @@ -116,19 +116,22 @@ def time(self): @property def xdim(self): - return get_dimensionality(self.xgcm_grid.axes.get("X")) + """Number of T (tracer) cells in the X direction.""" + return get_tracer_dimensionality(self.xgcm_grid.axes.get("X")) @property def ydim(self): - return get_dimensionality(self.xgcm_grid.axes.get("Y")) + """Number of T (tracer) cells in the Y direction.""" + return get_tracer_dimensionality(self.xgcm_grid.axes.get("Y")) @property def zdim(self): - return get_dimensionality(self.xgcm_grid.axes.get("Z")) + """Number of T (tracer) cells in the Z direction.""" + return get_tracer_dimensionality(self.xgcm_grid.axes.get("Z")) @property def tdim(self): - return get_dimensionality(self.xgcm_grid.axes.get("T")) + return get_tracer_dimensionality(self.xgcm_grid.axes.get("T")) @property def time_origin(self): From c03a8cf5d6737d1d1c2f993af3bcb02b42da8ca7 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 14:38:10 +0200 Subject: [PATCH 10/24] Add XGrid ravel_index and unravel_index --- parcels/xgrid.py | 45 ++++++++++++++++++++++++++++++++++++++++++ tests/v4/test_xgrid.py | 35 +++++++++++++++++++++++++++----- 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index cb93c2ec6f..b6155ad436 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -178,6 +178,51 @@ def _gtype(self): def search(self, z, y, x, ei=None, search2D=False): ... + def ravel_index(self, zi, yi, xi): + """ + Converts a z, y, and x index into a single encoded index. + + Parameters + ---------- + zi : int + Vertical index. + yi : int + Latitude index. + xi : int + Longitude index. + + Returns + ------- + int + Encoded index. + """ + return xi + self.xdim * yi + self.xdim * self.ydim * zi + + def unravel_index(self, ei): + """ + Converts a single encoded index back into a vertical index and face index. + + Parameters + ---------- + ei : int + Encoded index to be unraveled. + + Returns + ------- + zi : int + Vertical index. + yi : int + Latitude index. + xi : int + Longitude index. + """ + zi = ei // (self.xdim * self.ydim) + ei = ei % (self.xdim * self.ydim) + + yi = ei // self.xdim + xi = ei % self.xdim + return zi, yi, xi + def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_DIRECTION | None: """For a given dimension name in a grid, returns the direction axis it is on.""" diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 3d7aa24059..b97638882f 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -121,16 +121,41 @@ def test_invalid_lon_lat(ds): XGrid(xgcm.Grid(ds, periodic=False)) +def test_xgrid_ravel_unravel_index(): + ds = datasets["ds_2d_left"] + grid = XGrid(xgcm.Grid(ds, periodic=False)) + + xdim = grid.xdim + ydim = grid.ydim + zdim = grid.zdim + + encountered_eis = [] + for xi in range(xdim): + for yi in range(ydim): + for zi in range(zdim): + ei = grid.ravel_index(zi, yi, xi) + zi_test, yi_test, xi_test = grid.unravel_index(ei) + assert xi == xi_test, f"Expected xi {xi} but got {xi_test} for ei {ei}" + assert yi == yi_test, f"Expected yi {yi} but got {yi_test} for ei {ei}" + assert zi == zi_test, f"Expected zi {zi} but got {zi_test} for ei {ei}" + encountered_eis.append(ei) + + encountered_eis = sorted(encountered_eis) + assert len(set(encountered_eis)) == len(encountered_eis), "Raveled indices are not unique." + assert np.allclose(np.diff(np.array(encountered_eis)), 1), "Raveled indices are not consecutive integers." + assert encountered_eis[0] == 0, "Raveled indices do not start at 0." + + def test_iterate_over_cells(): - ny = 3 # Number of cells in the y-direction - nx = 6 # Number of cells in the x-direction - lon = np.arange(ny + 1) - lat = np.arange(nx + 1) + ydim = 3 # Number of cells in the y-direction + xdim = 6 # Number of cells in the x-direction + lon = np.arange(ydim + 1) + lat = np.arange(xdim + 1) LAT, LON = np.meshgrid(lat, lon, indexing="ij") # Call the function and collect the output cells = list(_iterate_over_cells(lat=LAT, lon=LON)) - assert len(cells) == ny * nx, "Number of cells does not match expected." + assert len(cells) == ydim * xdim, "Number of cells does not match expected." for cell in cells: _assert_point_is("east", 1, cell[0], cell[1]) From f4e0aad60d91b6e2a9a6c36885e381ee8b4868f0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 12 Jun 2025 16:38:17 +0200 Subject: [PATCH 11/24] Add 1D search --- parcels/xgrid.py | 44 ++++++++++++++++++++++++++++++++++++++++-- tests/v4/test_xgrid.py | 6 +++--- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index b6155ad436..4f82425c26 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -176,7 +176,21 @@ def _gtype(self): else: return GridType.CurvilinearSGrid - def search(self, z, y, x, ei=None, search2D=False): ... + def search(self, z, y, x, ei=None, search2D=False): + ds = self.xgcm_grid._ds + + if ds.lon.ndim == 1: + yi, bcoord_y = _search_1d_array(ds.lat.values, y) + xi, bcoord_x = _search_1d_array(ds.lon.values, x) + + if search2D: + zi = 0 + else: + zi, _ = _search_1d_array(ds.depth.values, z) + + return (zi, yi, xi), np.array([bcoord_y, bcoord_x, 1 - bcoord_y, 1 - bcoord_x]) + + raise NotImplementedError("Searching in 2D arrays is not implemented yet.") def ravel_index(self, zi, yi, xi): """ @@ -340,7 +354,7 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): ) -def _iterate_over_cells( +def _generate_cells( *, lat: np.ndarray, lon: np.ndarray, @@ -375,3 +389,29 @@ def _iterate_over_cells( [lat[y + 1, x], lon[y + 1, x]], # top left ] ) + + +def _search_1d_array( + arr: np.array, + x: float, +) -> tuple[int, int]: + """ + Searches for the particle location in a 1D array and return barycentric coordinate along dimension. + + Parameters + ---------- + arr : np.array + 1D array (assumed to be ascending) to search in. + x : float + Position in the 1D array to search for. + + Returns + ------- + int + Index of the element just before the position x in the array. + float + Barycentric coordinate. + """ + i = np.argmin(arr < x) + barry = (x - arr[i]) / (arr[i + 1] - arr[i]) + return i, barry diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index b97638882f..250f0c6285 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -12,7 +12,7 @@ from parcels.tools.converters import TimeConverter from parcels.xgrid import ( XGrid, - _iterate_over_cells, + _generate_cells, ) GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) @@ -146,7 +146,7 @@ def test_xgrid_ravel_unravel_index(): assert encountered_eis[0] == 0, "Raveled indices do not start at 0." -def test_iterate_over_cells(): +def test_generate_cells(): ydim = 3 # Number of cells in the y-direction xdim = 6 # Number of cells in the x-direction lon = np.arange(ydim + 1) @@ -154,7 +154,7 @@ def test_iterate_over_cells(): LAT, LON = np.meshgrid(lat, lon, indexing="ij") # Call the function and collect the output - cells = list(_iterate_over_cells(lat=LAT, lon=LON)) + cells = list(_generate_cells(lat=LAT, lon=LON)) assert len(cells) == ydim * xdim, "Number of cells does not match expected." for cell in cells: From e8a58aabe6ee894d63708e4d6a1b36cc2cb10569 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 13 Jun 2025 17:09:45 +0200 Subject: [PATCH 12/24] Remove unused type annotation --- parcels/xgrid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 4f82425c26..b6e2d2dd70 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -10,7 +10,6 @@ from parcels.tools.converters import TimeConverter _AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] -_AXIS_DIRECTION_SPATIAL = Literal["X", "Y", "Z"] _AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] _XGCM_AXES = Mapping[_AXIS_DIRECTION, xgcm.Axis] From ec99f1a95fa5e97a25cbd828d3e79e5bc3d83d5c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 23 Jun 2025 16:10:10 +0200 Subject: [PATCH 13/24] copy _search_indices_curvilinear and rename --- parcels/_index_search.py | 90 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 805e5e1eb8..11e8411bcb 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -270,6 +270,96 @@ def _search_indices_rectilinear( return (zeta, eta, xsi, _ei) +## TODO : Still need to implement the search_indices_curvilinear +def _search_indices_curvilinear_2d(field: Field, time, z, y, x, ti, particle=None, search2D=False): + if particle: + zi, yi, xi = field.unravel_index(particle.ei) + else: + xi = int(field.xdim / 2) - 1 + yi = int(field.ydim / 2) - 1 + xsi = eta = -1.0 + grid = field.grid + invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) + maxIterSearch = 1e6 + it = 0 + tol = 1.0e-10 + if not grid.zonal_periodic: + if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: + if grid.lon[0, 0] < grid.lon[0, -1]: + _raise_field_out_of_bound_error(z, y, x) + elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] + _raise_field_out_of_bound_error(z, y, x) + if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: + _raise_field_out_of_bound_error(z, y, x) + + while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: + px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) + if grid.mesh == "spherical": + px[0] = px[0] + 360 if px[0] < x - 225 else px[0] + px[0] = px[0] - 360 if px[0] > x + 225 else px[0] + px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) + px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) + a = np.dot(invA, px) + b = np.dot(invA, py) + + aa = a[3] * b[2] - a[2] * b[3] + bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] + cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] + if abs(aa) < 1e-12: # Rectilinear cell, or quasi + eta = -cc / bb + else: + det2 = bb * bb - 4 * aa * cc + if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter + det = np.sqrt(det2) + eta = (-bb + det) / (2 * aa) + if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg + xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5 + else: + xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta) + if xsi < 0 and eta < 0 and xi == 0 and yi == 0: + _raise_field_out_of_bound_error(0, y, x) + if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1: + _raise_field_out_of_bound_error(0, y, x) + if xsi < -tol: + xi -= 1 + elif xsi > 1 + tol: + xi += 1 + if eta < -tol: + yi -= 1 + elif eta > 1 + tol: + yi += 1 + (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) + it += 1 + if it > maxIterSearch: + print(f"Correct cell not found after {maxIterSearch} iterations") + _raise_field_out_of_bound_error(0, y, x) + xsi = max(0.0, xsi) + eta = max(0.0, eta) + xsi = min(1.0, xsi) + eta = min(1.0, eta) + + if grid.zdim > 1 and not search2D: + if grid._gtype == GridType.CurvilinearZGrid: + try: + (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) + except FieldOutOfBoundError: + _raise_field_out_of_bound_error(z, y, x) + elif grid._gtype == GridType.CurvilinearSGrid: + (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) + else: + zi = -1 + zeta = 0 + + if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): + _raise_field_sampling_error(z, y, x) + + if particle: + particle.ei[field.igrid] = field.ravel_index(zi, yi, xi) + + return (zeta, eta, xsi, zi, yi, xi) + + ## TODO : Still need to implement the search_indices_curvilinear def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False): if particle: From 346cbcaab498fdbe189004f4b20a4405e1d9fb2f Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 23 Jun 2025 17:22:13 +0200 Subject: [PATCH 14/24] Add function to do 2D curvilinear index search --- parcels/_index_search.py | 74 ++++++++++++++++------------------- parcels/xgrid.py | 4 ++ tests/v4/test_index_search.py | 54 +++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 40 deletions(-) create mode 100644 tests/v4/test_index_search.py diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 11e8411bcb..9320463e3e 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -17,6 +17,7 @@ _raise_field_sampling_error, _raise_time_extrapolation_error, ) +from parcels.xgrid import XGrid from .grid import GridType @@ -270,35 +271,43 @@ def _search_indices_rectilinear( return (zeta, eta, xsi, _ei) -## TODO : Still need to implement the search_indices_curvilinear -def _search_indices_curvilinear_2d(field: Field, time, z, y, x, ti, particle=None, search2D=False): - if particle: - zi, yi, xi = field.unravel_index(particle.ei) - else: - xi = int(field.xdim / 2) - 1 - yi = int(field.ydim / 2) - 1 +def _search_indices_curvilinear_2d( + grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None +): + yi, xi = yi_guess, xi_guess + if yi is None: + yi = int(grid.ydim / 2) - 1 + + if xi is None: + xi = int(grid.xdim / 2) - 1 + xsi = eta = -1.0 - grid = field.grid - invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) + invA = np.array( + [ + [1, 0, 0, 0], + [-1, 1, 0, 0], + [-1, 0, 0, 1], + [1, -1, 1, -1], + ] + ) maxIterSearch = 1e6 it = 0 tol = 1.0e-10 - if not grid.zonal_periodic: - if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: - if grid.lon[0, 0] < grid.lon[0, -1]: - _raise_field_out_of_bound_error(z, y, x) - elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] - _raise_field_out_of_bound_error(z, y, x) - if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: - _raise_field_out_of_bound_error(z, y, x) + + # # ! Error handling for out of bounds + # TODO: Re-enable in some capacity + # if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: + # if grid.lon[0, 0] < grid.lon[0, -1]: + # _raise_field_out_of_bound_error(y, x) + # elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] + # _raise_field_out_of_bound_error(z, y, x) + + # if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: + # _raise_field_out_of_bound_error(z, y, x) while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) - if grid.mesh == "spherical": - px[0] = px[0] + 360 if px[0] < x - 225 else px[0] - px[0] = px[0] - 360 if px[0] > x + 225 else px[0] - px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) - px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) a = np.dot(invA, px) b = np.dot(invA, py) @@ -339,25 +348,10 @@ def _search_indices_curvilinear_2d(field: Field, time, z, y, x, ti, particle=Non xsi = min(1.0, xsi) eta = min(1.0, eta) - if grid.zdim > 1 and not search2D: - if grid._gtype == GridType.CurvilinearZGrid: - try: - (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) - except FieldOutOfBoundError: - _raise_field_out_of_bound_error(z, y, x) - elif grid._gtype == GridType.CurvilinearSGrid: - (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) - else: - zi = -1 - zeta = 0 - - if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): - _raise_field_sampling_error(z, y, x) + if not ((0 <= xsi <= 1) and (0 <= eta <= 1)): + _raise_field_sampling_error(y, x) - if particle: - particle.ei[field.igrid] = field.ravel_index(zi, yi, xi) - - return (zeta, eta, xsi, zi, yi, xi) + return (eta, xsi, yi, xi) ## TODO : Still need to implement the search_indices_curvilinear diff --git a/parcels/xgrid.py b/parcels/xgrid.py index b6e2d2dd70..40d2b2c61d 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -41,6 +41,10 @@ class XGrid(BaseGrid): Class to represent a structured grid in Parcels. Wraps a xgcm-like Grid object (we use a trimmed down version of the xgcm.Grid class that is vendored with Parcels). This class provides methods and properties required for indexing and interpolating on the grid. + + Assumptions: + - If using Parcels in the context of a periodic simulation, the provided grid already has a halo + """ def __init__(self, grid: xgcm.Grid, mesh="flat"): diff --git a/tests/v4/test_index_search.py b/tests/v4/test_index_search.py new file mode 100644 index 0000000000..f39df131d1 --- /dev/null +++ b/tests/v4/test_index_search.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +from parcels import xgcm +from parcels._datasets.structured.generic import datasets +from parcels._index_search import _search_indices_curvilinear_2d +from parcels.field import Field +from parcels.xgrid import ( + XGrid, +) + + +@pytest.fixture +def field_cone(): + ds = datasets["2d_left_unrolled_cone"] + grid = XGrid(xgcm.Grid(ds, periodic=False)) + field = Field( + name="test_field", + data=ds["data_g"], + grid=grid, + ) + return field + + +def test_grid_indexing_fpoints(field_cone): + grid = field_cone.grid + + for yi_expected in range(grid.ydim - 1): + for xi_expected in range(grid.xdim - 1): + x = grid.lon[yi_expected, xi_expected] + 0.00001 + y = grid.lat[yi_expected, xi_expected] + 0.00001 + + eta, xsi, yi, xi = _search_indices_curvilinear_2d(grid, y, x) + if eta > 0.9: + yi_expected -= 1 + if xsi > 0.9: + xi_expected -= 1 + assert yi == yi_expected, f"Expected yi {yi_expected} but got {yi}" + assert xi == xi_expected, f"Expected xi {xi_expected} but got {xi}" + + cell_lon = [ + grid.lon[yi, xi], + grid.lon[yi, xi + 1], + grid.lon[yi + 1, xi + 1], + grid.lon[yi + 1, xi], + ] + cell_lat = [ + grid.lat[yi, xi], + grid.lat[yi, xi + 1], + grid.lat[yi + 1, xi + 1], + grid.lat[yi + 1, xi], + ] + assert x > np.min(cell_lon) and x < np.max(cell_lon) + assert y > np.min(cell_lat) and y < np.max(cell_lat) From 1531699180e77499c279762e8c6e0b1aa7f3385e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 24 Jun 2025 11:33:53 +0200 Subject: [PATCH 15/24] Implement XGrid.search with tests for 2D lon lat --- parcels/_index_search.py | 3 ++- parcels/xgrid.py | 27 ++++++++++++------- tests/v4/test_xgrid.py | 58 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 9320463e3e..2d43da79b8 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -17,11 +17,12 @@ _raise_field_sampling_error, _raise_time_extrapolation_error, ) -from parcels.xgrid import XGrid from .grid import GridType if TYPE_CHECKING: + from parcels.xgrid import XGrid + from .field import Field # from .grid import Grid diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 40d2b2c61d..1a391ad600 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -6,6 +6,7 @@ import xarray as xr from parcels import xgcm +from parcels._index_search import _search_indices_curvilinear_2d from parcels.basegrid import BaseGrid from parcels.tools.converters import TimeConverter @@ -182,18 +183,26 @@ def _gtype(self): def search(self, z, y, x, ei=None, search2D=False): ds = self.xgcm_grid._ds + if search2D: + zi = 0 + else: + zi, _ = _search_1d_array(ds.depth.values, z) + if ds.lon.ndim == 1: - yi, bcoord_y = _search_1d_array(ds.lat.values, y) - xi, bcoord_x = _search_1d_array(ds.lon.values, x) + yi, eta = _search_1d_array(ds.lat.values, y) + xi, xsi = _search_1d_array(ds.lon.values, x) + return (zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) - if search2D: - zi = 0 - else: - zi, _ = _search_1d_array(ds.depth.values, z) + yi, xi = None, None + if ei is not None: + _, yi, xi = self.unravel_index(ei) + + if ds.lon.ndim == 2: + eta, xsi, yi, xi = _search_indices_curvilinear_2d(self, y, x, yi, xi) - return (zi, yi, xi), np.array([bcoord_y, bcoord_x, 1 - bcoord_y, 1 - bcoord_x]) + return (zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) - raise NotImplementedError("Searching in 2D arrays is not implemented yet.") + raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") def ravel_index(self, zi, yi, xi): """ @@ -415,6 +424,6 @@ def _search_1d_array( float Barycentric coordinate. """ - i = np.argmin(arr < x) + i = np.argmin(arr <= x) - 1 barry = (x - arr[i]) / (arr[i + 1] - arr[i]) return i, barry diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 250f0c6285..a92f774187 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -10,10 +10,7 @@ from parcels._datasets.structured.generic import T, X, Y, Z, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter -from parcels.xgrid import ( - XGrid, - _generate_cells, -) +from parcels.xgrid import XGrid, _generate_cells, _search_1d_array GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) @@ -187,3 +184,56 @@ def _assert_point_is( raise ValueError(f"Invalid method: {direction}") np.testing.assert_allclose(reference_cell + delta, test_cell) + + +@pytest.mark.parametrize( + "ds", + [ + pytest.param(datasets["ds_2d_left"], id="1D lon/lat"), + pytest.param(datasets["2d_left_rotated"], id="2D lon/lat"), + ], +) # for key, ds in datasets.items()]) +def test_xgrid_search_cpoints(ds): + grid = XGrid(xgcm.Grid(ds, periodic=False)) + lat_array, lon_array = get_2d_fpoint_mesh(grid) + lat_array, lon_array = corner_to_cell_center_points(lat_array, lon_array) + + for xi in range(grid.xdim - 1): + for yi in range(grid.ydim - 1): + lat, lon = lat_array[yi, xi], lon_array[yi, xi] + (zi_test, yi_test, xi_test), bcoords = grid.search(0, lat, lon, ei=None, search2D=True) + assert xi == xi_test + assert yi == yi_test + assert zi_test == 0 + + # assert np.isclose(bcoords[0], 0.5) #? Should this not be the case with the cell center points? + # assert np.isclose(bcoords[1], 0.5) + + +def get_2d_fpoint_mesh(grid: XGrid): + lat, lon = grid.lat, grid.lon + if lon.ndim == 1: + lat, lon = np.meshgrid(lat, lon, indexing="ij") + return lat, lon + + +def corner_to_cell_center_points(lat, lon): + """Convert F points to C points.""" + lon_c = (lon[:-1, :-1] + lon[:-1, 1:]) / 2 + lat_c = (lat[:-1, :-1] + lat[1:, :-1]) / 2 + return lat_c, lon_c + + +@pytest.mark.parametrize( + "array, x, expected_xi, expected_xsi", + [ + (np.array([1, 2, 3, 4, 5]), 1.1, 0, 0.1), + (np.array([1, 2, 3, 4, 5]), 2.1, 1, 0.1), + (np.array([1, 2, 3, 4, 5]), 3.1, 2, 0.1), + (np.array([1, 2, 3, 4, 5]), 4.5, 3, 0.5), + ], +) +def test_search_1d_array(array, x, expected_xi, expected_xsi): + xi, xsi = _search_1d_array(array, x) + assert xi == expected_xi + assert np.isclose(xsi, expected_xsi) From 73feee05e4ce6ab0bd44b1a987c3dcd9fcab24ca Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 24 Jun 2025 13:42:49 +0200 Subject: [PATCH 16/24] Patch test --- parcels/xgrid.py | 2 ++ tests/v4/test_xgrid.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 1a391ad600..e16a767eff 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -410,6 +410,8 @@ def _search_1d_array( """ Searches for the particle location in a 1D array and return barycentric coordinate along dimension. + Assumes particle position x is within the bounds of the array, and array is increasing. + Parameters ---------- arr : np.array diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index a92f774187..b23cdcd4d0 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -88,7 +88,7 @@ def test_invalid_xgrid_field_array(): ... -def test_invalid_lon_lat(ds): +def test_invalid_lon_lat(): """Stress test the grid initialiser by creating incompatible datasets that test the edge cases""" ds = datasets["ds_2d_left"].copy() ds["lon"], ds["lat"] = xr.broadcast(ds["YC"], ds["XC"]) From fadbfbf968e33e053c3f87c25519cf68c157f491 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 24 Jun 2025 13:43:36 +0200 Subject: [PATCH 17/24] Remove unused _generate_cells helper --- parcels/xgrid.py | 37 --------------------------------- tests/v4/test_xgrid.py | 46 +----------------------------------------- 2 files changed, 1 insertion(+), 82 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index e16a767eff..7f6fd0b2f0 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -366,43 +366,6 @@ def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): ) -def _generate_cells( - *, - lat: np.ndarray, - lon: np.ndarray, -): - """ - Takes in 2 2D arrays representing the F points of a grid and returns a generator that yields - the individual cells of the grid as 2D arrays of the corners of the cells. - - Parameters - ---------- - lon : np.ndarray - 2D array of the longitude F points of the grid of shape (Y, X) - lat : np.ndarray - 2D array of the latitude F points of the grid of shape (Y, X). - - Yields - ------ - np.ndarray - 2D array of shape (4, 2) representing the corners of the cell in the order: - bottom left, bottom right, top right, top left. Output is provided in lat, lon order. - """ - assert lon.ndim == 2 and lat.ndim == 2, "lon and lat must be 2D arrays." - assert lon.shape == lat.shape, "lon and lat must have the same shape." - - for y in range(lon.shape[0] - 1): - for x in range(lon.shape[1] - 1): - yield np.array( - [ - [lat[y, x], lon[y, x]], # bottom left - [lat[y, x + 1], lon[y, x + 1]], # bottom right - [lat[y + 1, x + 1], lon[y + 1, x + 1]], # top right - [lat[y + 1, x], lon[y + 1, x]], # top left - ] - ) - - def _search_1d_array( arr: np.array, x: float, diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index b23cdcd4d0..ce0fb4b327 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -1,5 +1,4 @@ from collections import namedtuple -from typing import Literal import numpy as np import pytest @@ -10,7 +9,7 @@ from parcels._datasets.structured.generic import T, X, Y, Z, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter -from parcels.xgrid import XGrid, _generate_cells, _search_1d_array +from parcels.xgrid import XGrid, _search_1d_array GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) @@ -143,49 +142,6 @@ def test_xgrid_ravel_unravel_index(): assert encountered_eis[0] == 0, "Raveled indices do not start at 0." -def test_generate_cells(): - ydim = 3 # Number of cells in the y-direction - xdim = 6 # Number of cells in the x-direction - lon = np.arange(ydim + 1) - lat = np.arange(xdim + 1) - LAT, LON = np.meshgrid(lat, lon, indexing="ij") - - # Call the function and collect the output - cells = list(_generate_cells(lat=LAT, lon=LON)) - assert len(cells) == ydim * xdim, "Number of cells does not match expected." - - for cell in cells: - _assert_point_is("east", 1, cell[0], cell[1]) - _assert_point_is("north", 1, cell[1], cell[2]) - _assert_point_is("west", 1, cell[2], cell[3]) - - -def test__assert_point_is(): - _assert_point_is("east", 1, np.array([0, 0]), np.array([0, 1])) - _assert_point_is("west", 1, np.array([0, 1]), np.array([0, 0])) - _assert_point_is("north", 1, np.array([0, 0]), np.array([1, 0])) - _assert_point_is("south", 1, np.array([1, 0]), np.array([0, 0])) - - -def _assert_point_is( - direction: Literal["east", "west", "north", "south"], by: int, reference_cell: np.ndarray, test_cell: np.ndarray -): - """cell1 and cell2 are arrays of (lat, lon)""" - match direction: - case "east": - delta = np.array([0, by]) - case "west": - delta = np.array([0, -by]) - case "north": - delta = np.array([by, 0]) - case "south": - delta = np.array([-by, 0]) - case _: - raise ValueError(f"Invalid method: {direction}") - - np.testing.assert_allclose(reference_cell + delta, test_cell) - - @pytest.mark.parametrize( "ds", [ From 2085971b766044872e18eb30a7f1d5d064bf560c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 24 Jun 2025 15:37:55 +0200 Subject: [PATCH 18/24] Review feedback --- parcels/xgrid.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 7f6fd0b2f0..4f2304eb20 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -52,7 +52,7 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"): self.xgcm_grid = grid self.mesh = mesh ds = grid._ds - assert_valid_lon_lat(ds["lon"], ds["lat"], grid.axes) + assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes) # ! Not ideal... Triggers computation on a throwaway item. Keeping for now for v3 compat, will be removed in v4. self.lonlat_minmax = np.array( @@ -226,7 +226,7 @@ def ravel_index(self, zi, yi, xi): def unravel_index(self, ei): """ - Converts a single encoded index back into a vertical index and face index. + Converts a single encoded index back into a Z, Y, and X indices. Parameters ---------- @@ -303,7 +303,7 @@ def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): ) -def assert_valid_lon_lat(da_lon, da_lat, axes: _XGCM_AXES): +def assert_valid_lat_lon(da_lat, da_lon, axes: _XGCM_AXES): """ Asserts that the provided longitude and latitude DataArrays are defined appropriately on the F points to match the internal representation in Parcels. @@ -371,7 +371,7 @@ def _search_1d_array( x: float, ) -> tuple[int, int]: """ - Searches for the particle location in a 1D array and return barycentric coordinate along dimension. + Searches for the particle location in a 1D array and returns barycentric coordinate along dimension. Assumes particle position x is within the bounds of the array, and array is increasing. @@ -390,5 +390,5 @@ def _search_1d_array( Barycentric coordinate. """ i = np.argmin(arr <= x) - 1 - barry = (x - arr[i]) / (arr[i + 1] - arr[i]) - return i, barry + bcoord = (x - arr[i]) / (arr[i + 1] - arr[i]) + return i, bcoord From 1166b02a284b21c33e9d86124f5c09f424383ad0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 24 Jun 2025 15:42:35 +0200 Subject: [PATCH 19/24] Patch return of `XGrid.search()` to match BaseGrid --- parcels/xgrid.py | 4 ++-- tests/v4/test_xgrid.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 4f2304eb20..6cf75e95dd 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -191,7 +191,7 @@ def search(self, z, y, x, ei=None, search2D=False): if ds.lon.ndim == 1: yi, eta = _search_1d_array(ds.lat.values, y) xi, xsi = _search_1d_array(ds.lon.values, x) - return (zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) + return self.ravel_index(zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) yi, xi = None, None if ei is not None: @@ -200,7 +200,7 @@ def search(self, z, y, x, ei=None, search2D=False): if ds.lon.ndim == 2: eta, xsi, yi, xi = _search_indices_curvilinear_2d(self, y, x, yi, xi) - return (zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) + return self.ravel_index(zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index ce0fb4b327..fd2da4b90e 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -157,7 +157,8 @@ def test_xgrid_search_cpoints(ds): for xi in range(grid.xdim - 1): for yi in range(grid.ydim - 1): lat, lon = lat_array[yi, xi], lon_array[yi, xi] - (zi_test, yi_test, xi_test), bcoords = grid.search(0, lat, lon, ei=None, search2D=True) + ei, bcoords = grid.search(0, lat, lon, ei=None, search2D=True) + zi_test, yi_test, xi_test = grid.unravel_index(ei) assert xi == xi_test assert yi == yi_test assert zi_test == 0 From 6346183ec1849df7981ee3cd478c646edb1bd153 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 24 Jun 2025 15:47:01 +0200 Subject: [PATCH 20/24] Assert 1D lon and lat arrays are strictly monotonically increasing --- parcels/xgrid.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 6cf75e95dd..d312382d79 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -353,6 +353,13 @@ def assert_valid_lat_lon(da_lat, da_lon, axes: _XGCM_AXES): f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is not associated with the Y axis." ) + if not np.all(np.diff(da_lon.values) > 0): + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} must be strictly increasing." + ) + if not np.all(np.diff(da_lat.values) > 0): + raise ValueError(f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be strictly increasing.") + if da_lon.ndim == 2: if da_lon.dims != da_lat.dims: raise ValueError( @@ -373,12 +380,14 @@ def _search_1d_array( """ Searches for the particle location in a 1D array and returns barycentric coordinate along dimension. - Assumes particle position x is within the bounds of the array, and array is increasing. + Assumptions: + - particle position x is within the bounds of the array + - array is strictly monotonically increasing. Parameters ---------- arr : np.array - 1D array (assumed to be ascending) to search in. + 1D array to search in. x : float Position in the 1D array to search for. From 891ecaa8dccb25f0012836f02f01e171b8d57913 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 25 Jun 2025 16:03:03 +0200 Subject: [PATCH 21/24] Patch XGrid.search return order --- parcels/xgrid.py | 4 ++-- tests/v4/test_xgrid.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index d312382d79..fb42d68c1b 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -191,7 +191,7 @@ def search(self, z, y, x, ei=None, search2D=False): if ds.lon.ndim == 1: yi, eta = _search_1d_array(ds.lat.values, y) xi, xsi = _search_1d_array(ds.lon.values, x) - return self.ravel_index(zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) + return np.array([eta, xsi, 1 - eta, 1 - xsi]), self.ravel_index(zi, yi, xi) yi, xi = None, None if ei is not None: @@ -200,7 +200,7 @@ def search(self, z, y, x, ei=None, search2D=False): if ds.lon.ndim == 2: eta, xsi, yi, xi = _search_indices_curvilinear_2d(self, y, x, yi, xi) - return self.ravel_index(zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi]) + return np.array([eta, xsi, 1 - eta, 1 - xsi]), self.ravel_index(zi, yi, xi) raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index fd2da4b90e..05d2f463de 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -157,7 +157,7 @@ def test_xgrid_search_cpoints(ds): for xi in range(grid.xdim - 1): for yi in range(grid.ydim - 1): lat, lon = lat_array[yi, xi], lon_array[yi, xi] - ei, bcoords = grid.search(0, lat, lon, ei=None, search2D=True) + bcoords, ei = grid.search(0, lat, lon, ei=None, search2D=True) zi_test, yi_test, xi_test = grid.unravel_index(ei) assert xi == xi_test assert yi == yi_test From 2e5b593eeec202f72efbb4529e3584b7b175cdd4 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 27 Jun 2025 16:58:23 +0200 Subject: [PATCH 22/24] Update xgrid and basegrid search, ravel, unravel API --- parcels/basegrid.py | 100 +++++++++++++++++++++++++++++++++++++---- parcels/xgrid.py | 75 ++++++++++--------------------- tests/v4/test_xgrid.py | 19 ++++---- 3 files changed, 125 insertions(+), 69 deletions(-) diff --git a/parcels/basegrid.py b/parcels/basegrid.py index 813a7ced7d..690ffaa675 100644 --- a/parcels/basegrid.py +++ b/parcels/basegrid.py @@ -1,15 +1,21 @@ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np class BaseGrid(ABC): @abstractmethod - def search(self, z: float, y: float, x: float, ei=None, search2D: bool = False): + def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int, float | np.ndarray]]: """ Perform a spatial (and optionally vertical) search to locate the grid element that contains a given point (x, y, z). This method delegates to grid-type-specific logic (e.g., structured or unstructured) - to determine the appropriate indices and interpolation coordinates for evaluating a field. + to determine the appropriate indices and barycentric coordinates for evaluating a field. Parameters ---------- @@ -28,12 +34,21 @@ def search(self, z: float, y: float, x: float, ei=None, search2D: bool = False): Returns ------- - bcoords : np.ndarray or tuple - Interpolation weights or barycentric coordinates within the containing cell/face. - The interpretation of `bcoords` depends on the grid type. - ei : int - Encoded index of the identified grid cell or face. This value can be cached for - future lookups to accelerate repeated searches. + dict + A dictionary mapping spatial axis names to tuples of (index, barycentric_coordinates). + The returned axes depend on the grid dimensionality and type: + + - 3D structured grid: {"X": (xi, xsi), "Y": (yi, eta), "Z": (zi, zeta)} + - 2D structured grid: {"X": (xi, xsi), "Y": (yi, eta)} + - 1D structured grid (depth): {"Z": (zi, zeta)} + - Unstructured grid: {"Z": (zi, zeta), "FACE": (fi, bcoords)} + + Where: + - index (int): The cell position of a particle along the given axis + - barycentric_coordinates (float or np.ndarray): The coordinates defining + a particle's position within the grid cell. For structured grids, this + is a single coordinate per axis; for unstructured grids, this can be + an array of coordinates for the face polygon. Raises ------ @@ -43,3 +58,72 @@ def search(self, z: float, y: float, x: float, ei=None, search2D: bool = False): Raised if the search method is not implemented for the current grid type. """ ... + + @abstractmethod + def ravel_index(self, axis_indices: dict[str, int]) -> int: + """ + Convert a dictionary of axis indices to a single encoded index (ei). + + This method takes the individual indices for each spatial axis and combines them + into a single integer that uniquely identifies a grid cell. This encoded + index can be used for efficient caching and lookup operations. + + Parameters + ---------- + axis_indices : dict[str, int] + A dictionary mapping axis names to their corresponding indices. + The expected keys depend on the grid dimensionality and type: + + - 3D structured grid: {"X": xi, "Y": yi, "Z": zi} + - 2D structured grid: {"X": xi, "Y": yi} + - 1D structured grid: {"Z": zi} + - Unstructured grid: {"Z": zi, "FACE": fi} + + Returns + ------- + int + The encoded index (ei) representing the unique grid cell or face. + + Raises + ------ + KeyError + Raised when required axis keys are missing from axis_indices. + ValueError + Raised when index values are out of bounds for the grid. + NotImplementedError + Raised if the method is not implemented for the current grid type. + """ + ... + + @abstractmethod + def unravel_index(self, ei: int) -> dict[str, int]: + """ + Convert a single encoded index (ei) back to a dictionary of axis indices. + + This method is the inverse of ravel_index, taking an encoded index and + decomposing it back into the individual indices for each spatial axis. + + Parameters + ---------- + ei : int + The encoded index representing a unique grid cell or face. + + Returns + ------- + dict[str, int] + A dictionary mapping axis names to their corresponding indices. + The returned keys depend on the grid dimensionality and type: + + - 3D structured grid: {"X": xi, "Y": yi, "Z": zi} + - 2D structured grid: {"X": xi, "Y": yi} + - 1D structured grid: {"Z": zi} + - Unstructured grid: {"Z": zi, "FACE": fi} + + Raises + ------ + ValueError + Raised when the encoded index is out of bounds or invalid for the grid. + NotImplementedError + Raised if the method is not implemented for the current grid type. + """ + ... diff --git a/parcels/xgrid.py b/parcels/xgrid.py index fb42d68c1b..fd131c0df8 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -10,9 +10,10 @@ from parcels.basegrid import BaseGrid from parcels.tools.converters import TimeConverter -_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] -_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] -_XGCM_AXES = Mapping[_AXIS_DIRECTION, xgcm.Axis] +_XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] +_XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] +_AXIS_DIRECTION = Literal["X", "Y", "Z"] +_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis] def get_tracer_dimensionality(axis: xgcm.Axis | None) -> int: @@ -180,77 +181,49 @@ def _gtype(self): else: return GridType.CurvilinearSGrid - def search(self, z, y, x, ei=None, search2D=False): + def search(self, z, y, x, ei=None): ds = self.xgcm_grid._ds - if search2D: - zi = 0 - else: - zi, _ = _search_1d_array(ds.depth.values, z) + zi, zeta = _search_1d_array(ds.depth.values, z) if ds.lon.ndim == 1: yi, eta = _search_1d_array(ds.lat.values, y) xi, xsi = _search_1d_array(ds.lon.values, x) - return np.array([eta, xsi, 1 - eta, 1 - xsi]), self.ravel_index(zi, yi, xi) + return {"X": (xi, xsi), "Y": (yi, eta), "Z": (zi, zeta)} yi, xi = None, None if ei is not None: - _, yi, xi = self.unravel_index(ei) + axis_indices = self.unravel_index(ei) + xi = axis_indices.get("X") + yi = axis_indices.get("Y") if ds.lon.ndim == 2: eta, xsi, yi, xi = _search_indices_curvilinear_2d(self, y, x, yi, xi) - return np.array([eta, xsi, 1 - eta, 1 - xsi]), self.ravel_index(zi, yi, xi) + return {"X": (xi, xsi), "Y": (yi, eta), "Z": (zi, zeta)} raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") - def ravel_index(self, zi, yi, xi): - """ - Converts a z, y, and x index into a single encoded index. - - Parameters - ---------- - zi : int - Vertical index. - yi : int - Latitude index. - xi : int - Longitude index. - - Returns - ------- - int - Encoded index. - """ + def ravel_index(self, axis_indices: dict[_AXIS_DIRECTION, int]) -> int: + xi = axis_indices.get("X", 0) + yi = axis_indices.get("Y", 0) + zi = axis_indices.get("Z", 0) return xi + self.xdim * yi + self.xdim * self.ydim * zi - def unravel_index(self, ei): - """ - Converts a single encoded index back into a Z, Y, and X indices. - - Parameters - ---------- - ei : int - Encoded index to be unraveled. - - Returns - ------- - zi : int - Vertical index. - yi : int - Latitude index. - xi : int - Longitude index. - """ + def unravel_index(self, ei) -> dict[_AXIS_DIRECTION, int]: zi = ei // (self.xdim * self.ydim) ei = ei % (self.xdim * self.ydim) yi = ei // self.xdim xi = ei % self.xdim - return zi, yi, xi + return { + "X": xi, + "Y": yi, + "Z": zi, + } -def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_DIRECTION | None: +def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: """For a given dimension name in a grid, returns the direction axis it is on.""" for axis_name, axis in axes.items(): if dim in axis.coords.values(): @@ -258,7 +231,7 @@ def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_DIRECTION | None return None -def get_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _AXIS_POSITION | None: +def get_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_POSITION | None: """For a given dimension, returns the position of the variable in the grid.""" for axis in axes.values(): var_to_position = {var: position for position, var in axis.coords.items()} @@ -287,7 +260,7 @@ def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): assert_all_dimensions_correspond_with_axis(da, axes) dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} - dim_to_axis = cast(dict[Hashable, _AXIS_DIRECTION], dim_to_axis) + dim_to_axis = cast(dict[Hashable, _XGCM_AXIS_DIRECTION], dim_to_axis) # Assert all dimensions are present if set(dim_to_axis.values()) != {"T", "Z", "Y", "X"}: diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 05d2f463de..91806706eb 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -129,11 +129,10 @@ def test_xgrid_ravel_unravel_index(): for xi in range(xdim): for yi in range(ydim): for zi in range(zdim): - ei = grid.ravel_index(zi, yi, xi) - zi_test, yi_test, xi_test = grid.unravel_index(ei) - assert xi == xi_test, f"Expected xi {xi} but got {xi_test} for ei {ei}" - assert yi == yi_test, f"Expected yi {yi} but got {yi_test} for ei {ei}" - assert zi == zi_test, f"Expected zi {zi} but got {zi_test} for ei {ei}" + axis_indices = {"X": xi, "Y": yi, "Z": zi} + ei = grid.ravel_index(axis_indices) + axis_indices_test = grid.unravel_index(ei) + assert axis_indices_test == axis_indices encountered_eis.append(ei) encountered_eis = sorted(encountered_eis) @@ -156,12 +155,12 @@ def test_xgrid_search_cpoints(ds): for xi in range(grid.xdim - 1): for yi in range(grid.ydim - 1): + axis_indices = {"X": xi, "Y": yi, "Z": 0} + lat, lon = lat_array[yi, xi], lon_array[yi, xi] - bcoords, ei = grid.search(0, lat, lon, ei=None, search2D=True) - zi_test, yi_test, xi_test = grid.unravel_index(ei) - assert xi == xi_test - assert yi == yi_test - assert zi_test == 0 + axis_indices_bcoords = grid.search(0, lat, lon, ei=None) + axis_indices_test = {k: v[0] for k, v in axis_indices_bcoords.items()} + assert axis_indices == axis_indices_test # assert np.isclose(bcoords[0], 0.5) #? Should this not be the case with the cell center points? # assert np.isclose(bcoords[1], 0.5) From de0a61640ac7adcf90f2839cb6ac75e46a49f7ff Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 30 Jun 2025 10:29:33 +0200 Subject: [PATCH 23/24] Update XGrid.{xdim,ydim,zdim,tdim} code documentation and helper --- parcels/xgrid.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index fd131c0df8..4947a03750 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -16,22 +16,13 @@ _XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis] -def get_tracer_dimensionality(axis: xgcm.Axis | None) -> int: +def get_n_cell_edges_along_dim(axis: xgcm.Axis | None) -> int: if axis is None: return 1 first_coord = list(axis.coords.items())[0] - pos, coord = first_coord + _, coord_var = first_coord - pos_to_dim = { # TODO: These could do with being explicitly tested - "center": lambda x: x, - "left": lambda x: x, - "right": lambda x: x, - "inner": lambda x: x + 1, - "outer": lambda x: x - 1, - } - - n = axis._ds[coord].size - return pos_to_dim[pos](n) + return axis._ds[coord_var].size def get_time(axis: xgcm.Axis) -> npt.NDArray: @@ -121,22 +112,19 @@ def time(self): @property def xdim(self): - """Number of T (tracer) cells in the X direction.""" - return get_tracer_dimensionality(self.xgcm_grid.axes.get("X")) + return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("X")) @property def ydim(self): - """Number of T (tracer) cells in the Y direction.""" - return get_tracer_dimensionality(self.xgcm_grid.axes.get("Y")) + return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("Y")) @property def zdim(self): - """Number of T (tracer) cells in the Z direction.""" - return get_tracer_dimensionality(self.xgcm_grid.axes.get("Z")) + return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("Z")) @property def tdim(self): - return get_tracer_dimensionality(self.xgcm_grid.axes.get("T")) + return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("T")) @property def time_origin(self): From 2063abd789573bf52bd8036b7d85a849f892854c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:56:35 +0200 Subject: [PATCH 24/24] Review feedback --- parcels/_index_search.py | 2 +- parcels/basegrid.py | 12 ++++++------ parcels/xgrid.py | 34 +++++++++++++++------------------- tests/v4/test_index_search.py | 2 +- tests/v4/test_xgrid.py | 4 ++-- 5 files changed, 25 insertions(+), 29 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 2d43da79b8..ca82359451 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -352,7 +352,7 @@ def _search_indices_curvilinear_2d( if not ((0 <= xsi <= 1) and (0 <= eta <= 1)): _raise_field_sampling_error(y, x) - return (eta, xsi, yi, xi) + return (yi, eta, xi, xsi) ## TODO : Still need to implement the search_indices_curvilinear diff --git a/parcels/basegrid.py b/parcels/basegrid.py index 690ffaa675..a7ebcfb42a 100644 --- a/parcels/basegrid.py +++ b/parcels/basegrid.py @@ -38,8 +38,8 @@ def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int, A dictionary mapping spatial axis names to tuples of (index, barycentric_coordinates). The returned axes depend on the grid dimensionality and type: - - 3D structured grid: {"X": (xi, xsi), "Y": (yi, eta), "Z": (zi, zeta)} - - 2D structured grid: {"X": (xi, xsi), "Y": (yi, eta)} + - 3D structured grid: {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} + - 2D structured grid: {"Y": (yi, eta), "X": (xi, xsi)} - 1D structured grid (depth): {"Z": (zi, zeta)} - Unstructured grid: {"Z": (zi, zeta), "FACE": (fi, bcoords)} @@ -74,8 +74,8 @@ def ravel_index(self, axis_indices: dict[str, int]) -> int: A dictionary mapping axis names to their corresponding indices. The expected keys depend on the grid dimensionality and type: - - 3D structured grid: {"X": xi, "Y": yi, "Z": zi} - - 2D structured grid: {"X": xi, "Y": yi} + - 3D structured grid: {"Z": zi, "Y": yi, "X": xi} + - 2D structured grid: {"Y": yi, "X": xi} - 1D structured grid: {"Z": zi} - Unstructured grid: {"Z": zi, "FACE": fi} @@ -114,8 +114,8 @@ def unravel_index(self, ei: int) -> dict[str, int]: A dictionary mapping axis names to their corresponding indices. The returned keys depend on the grid dimensionality and type: - - 3D structured grid: {"X": xi, "Y": yi, "Z": zi} - - 2D structured grid: {"X": xi, "Y": yi} + - 3D structured grid: {"Z": zi, "Y": yi, "X": xi} + - 2D structured grid: {"Y": yi, "X": xi} - 1D structured grid: {"Z": zi} - Unstructured grid: {"Z": zi, "FACE": fi} diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 4947a03750..c0bb066513 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -16,7 +16,7 @@ _XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis] -def get_n_cell_edges_along_dim(axis: xgcm.Axis | None) -> int: +def get_cell_edge_count_along_dim(axis: xgcm.Axis | None) -> int: if axis is None: return 1 first_coord = list(axis.coords.items())[0] @@ -36,7 +36,7 @@ class XGrid(BaseGrid): This class provides methods and properties required for indexing and interpolating on the grid. Assumptions: - - If using Parcels in the context of a periodic simulation, the provided grid already has a halo + - If using Parcels in the context of a spatially periodic simulation, the provided grid already has a halo """ @@ -112,19 +112,19 @@ def time(self): @property def xdim(self): - return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("X")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("X")) @property def ydim(self): - return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("Y")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("Y")) @property def zdim(self): - return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("Z")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("Z")) @property def tdim(self): - return get_n_cell_edges_along_dim(self.xgcm_grid.axes.get("T")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("T")) @property def time_origin(self): @@ -177,7 +177,7 @@ def search(self, z, y, x, ei=None): if ds.lon.ndim == 1: yi, eta = _search_1d_array(ds.lat.values, y) xi, xsi = _search_1d_array(ds.lon.values, x) - return {"X": (xi, xsi), "Y": (yi, eta), "Z": (zi, zeta)} + return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} yi, xi = None, None if ei is not None: @@ -186,9 +186,9 @@ def search(self, z, y, x, ei=None): yi = axis_indices.get("Y") if ds.lon.ndim == 2: - eta, xsi, yi, xi = _search_indices_curvilinear_2d(self, y, x, yi, xi) + yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi) - return {"X": (xi, xsi), "Y": (yi, eta), "Z": (zi, zeta)} + return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") @@ -204,11 +204,7 @@ def unravel_index(self, ei) -> dict[_AXIS_DIRECTION, int]: yi = ei // self.xdim xi = ei % self.xdim - return { - "X": xi, - "Y": yi, - "Z": zi, - } + return {"Z": zi, "Y": yi, "X": xi} def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: @@ -219,7 +215,7 @@ def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | return None -def get_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_POSITION | None: +def get_xgcm_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_POSITION | None: """For a given dimension, returns the position of the variable in the grid.""" for axis in axes.values(): var_to_position = {var: position for position, var in axis.coords.items()} @@ -281,16 +277,16 @@ def assert_valid_lat_lon(da_lat, da_lon, axes: _XGCM_AXES): assert_all_dimensions_correspond_with_axis(da_lon, axes) assert_all_dimensions_correspond_with_axis(da_lat, axes) - dim_to_position = {dim: get_position_from_dim_name(axes, dim) for dim in da_lon.dims} - dim_to_position.update({dim: get_position_from_dim_name(axes, dim) for dim in da_lat.dims}) + dim_to_position = {dim: get_xgcm_position_from_dim_name(axes, dim) for dim in da_lon.dims} + dim_to_position.update({dim: get_xgcm_position_from_dim_name(axes, dim) for dim in da_lat.dims}) for dim in da_lon.dims: - if get_position_from_dim_name(axes, dim) == "center": + if get_xgcm_position_from_dim_name(axes, dim) == "center": raise ValueError( f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} is defined on the center of the grid, but must be defined on the F points." ) for dim in da_lat.dims: - if get_position_from_dim_name(axes, dim) == "center": + if get_xgcm_position_from_dim_name(axes, dim) == "center": raise ValueError( f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is defined on the center of the grid, but must be defined on the F points." ) diff --git a/tests/v4/test_index_search.py b/tests/v4/test_index_search.py index f39df131d1..7f9290f12c 100644 --- a/tests/v4/test_index_search.py +++ b/tests/v4/test_index_search.py @@ -30,7 +30,7 @@ def test_grid_indexing_fpoints(field_cone): x = grid.lon[yi_expected, xi_expected] + 0.00001 y = grid.lat[yi_expected, xi_expected] + 0.00001 - eta, xsi, yi, xi = _search_indices_curvilinear_2d(grid, y, x) + yi, eta, xi, xsi = _search_indices_curvilinear_2d(grid, y, x) if eta > 0.9: yi_expected -= 1 if xsi > 0.9: diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 91806706eb..39a5ee614d 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -129,7 +129,7 @@ def test_xgrid_ravel_unravel_index(): for xi in range(xdim): for yi in range(ydim): for zi in range(zdim): - axis_indices = {"X": xi, "Y": yi, "Z": zi} + axis_indices = {"Z": zi, "Y": yi, "X": xi} ei = grid.ravel_index(axis_indices) axis_indices_test = grid.unravel_index(ei) assert axis_indices_test == axis_indices @@ -155,7 +155,7 @@ def test_xgrid_search_cpoints(ds): for xi in range(grid.xdim - 1): for yi in range(grid.ydim - 1): - axis_indices = {"X": xi, "Y": yi, "Z": 0} + axis_indices = {"Z": 0, "Y": yi, "X": xi} lat, lon = lat_array[yi, xi], lon_array[yi, xi] axis_indices_bcoords = grid.search(0, lat, lon, ei=None)