Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions parcels/_datasets/unstructured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def _stommel_gyre_delaunay():
lon, lat = np.meshgrid(np.linspace(0, 60.0, Nx, dtype=np.float32), np.linspace(0, 60.0, Nx, dtype=np.float32))
lon_flat = lon.ravel()
lat_flat = lat.ravel()
zf = np.linspace(0.0, 1000.0, 2, endpoint=True, dtype=np.float32) # Vertical element faces
zc = 0.5 * (zf[:-1] + zf[1:]) # Vertical element centers
nz = zf.size
nz1 = zc.size

# mask any point on one of the boundaries
mask = (
Expand All @@ -40,9 +44,10 @@ def _stommel_gyre_delaunay():
uxgrid.attrs["Conventions"] = "UGRID-1.0"

# Define arrays U (zonal), V (meridional) and P (sea surface height)
U = np.zeros((1, 1, lat.size), dtype=np.float64)
V = np.zeros((1, 1, lat.size), dtype=np.float64)
P = np.zeros((1, 1, lat.size), dtype=np.float64)
U = np.zeros((1, nz1, lat.size), dtype=np.float64)
V = np.zeros((1, nz1, lat.size), dtype=np.float64)
W = np.zeros((1, nz, lat.size), dtype=np.float64)
P = np.zeros((1, nz1, lat.size), dtype=np.float64)

for i, (x, y) in enumerate(zip(lon_flat, lat_flat, strict=False)):
xi = x / 60.0
Expand Down Expand Up @@ -72,7 +77,20 @@ def _stommel_gyre_delaunay():
dims=["time", "nz1", "n_node"],
coords=dict(
time=(["time"], [TIME[0]]),
nz1=(["nz1"], [0]),
nz1=(["nz1"], zc),
),
attrs=dict(
description="meridional velocity", units="m/s", location="node", mesh="delaunay", Conventions="UGRID-1.0"
),
)
w = ux.UxDataArray(
data=W,
name="W",
uxgrid=uxgrid,
dims=["time", "nz", "n_node"],
coords=dict(
time=(["time"], [TIME[0]]),
nz=(["nz"], zf),
),
attrs=dict(
description="meridional velocity", units="m/s", location="node", mesh="delaunay", Conventions="UGRID-1.0"
Expand All @@ -85,12 +103,12 @@ def _stommel_gyre_delaunay():
dims=["time", "nz1", "n_node"],
coords=dict(
time=(["time"], [TIME[0]]),
nz1=(["nz1"], [0]),
nz1=(["nz1"], zc),
),
attrs=dict(description="pressure", units="N/m^2", location="node", mesh="delaunay", Conventions="UGRID-1.0"),
)

return ux.UxDataset({"U": u, "V": v, "p": p}, uxgrid=uxgrid)
return ux.UxDataset({"U": u, "V": v, "W": w, "p": p}, uxgrid=uxgrid)


def _fesom2_square_delaunay_uniform_z_coordinate():
Expand Down
28 changes: 18 additions & 10 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ def UXPiecewiseConstantFace(
This interpolation method is appropriate for fields that are
face registered, such as u,v in FESOM.
"""
# TODO joe : handle vertical interpolation
zi, fi = field.unravel_index(ei)
return field.data[ti, zi, fi]
zi, fi = field.grid.unravel_index(ei)
return field.data.values[ti, zi, fi]


def UXPiecewiseLinearNode(
Expand All @@ -43,11 +42,20 @@ def UXPiecewiseLinearNode(
x: np.float32 | np.float64,
):
"""
Piecewise linear interpolation kernel for node registered data. This
interpolation method is appropriate for fields that are node registered
such as the vertical velocity w in FESOM.
Piecewise linear interpolation kernel for node registered data located at vertical interface levels.
This interpolation method is appropriate for fields that are node registered such as the vertical
velocity W in FESOM2. Effectively, it applies barycentric interpolation in the lateral direction
and piecewise linear interpolation in the vertical direction.
"""
# TODO joe : handle vertical interpolation
zi, fi = field.unravel_index(ei)
node_ids = field.data.uxgrid.face_node_connectivity[fi, :]
return np.dot(field.data[ti, zi, node_ids], bcoords)
k, fi = field.grid.unravel_index(ei)
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :]
# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
# First, do barycentric interpolation in the lateral direction for each interface level
fzk = np.dot(field.data.values[ti, k, node_ids], bcoords)
fzkp1 = np.dot(field.data.values[ti, k + 1, node_ids], bcoords)

# Then, do piecewise linear interpolation in the vertical direction
zk = field.grid.z.values[k]
zkp1 = field.grid.z.values[k + 1]
return (fzk * (zkp1 - z) + fzkp1 * (z - zk)) / (zkp1 - zk) # Linear interpolation in the vertical direction
2 changes: 1 addition & 1 deletion parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):

try:
tau, ti = _search_time_index(self, time)
bcoords, _ei = self.grid.search(self, z, y, x, ei=_ei)
bcoords, _ei = self.grid.search(z, y, x, ei=_ei)
value = self._interp_method(self, ti, _ei, bcoords, tau, time, z, y, x)

if np.isnan(value):
Expand Down
51 changes: 41 additions & 10 deletions parcels/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,72 @@
for interpolation on unstructured grids.
"""

def __init__(self, grid: ux.grid.Grid) -> UxGrid:
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
"""
Initializes the UxGrid with a uxarray grid and vertical coordinate array.

Parameters
----------
grid : ux.grid.Grid
The uxarray grid object containing the unstructured grid data.
z : ux.UxDataArray
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
"""
self.uxgrid = grid
if not isinstance(z, ux.UxDataArray):
raise TypeError("z must be an instance of ux.UxDataArray")

Check warning on line 33 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L33

Added line #L33 was not covered by tests
if z.ndim != 1:
raise ValueError("z must be a 1D array of vertical coordinates")

Check warning on line 35 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L35

Added line #L35 was not covered by tests
self.z = z

def search(
self, z: float, y: float, x: float, ei: int | None = None, search2D: bool = False
) -> tuple[np.ndarray, int]:
tol = 1e-10

def try_face(fid):
# TODO : Vertical search is not implemented yet, so we assume z is not used.
bcoords, err = self.uxgrid._get_barycentric_coordinates(y, x, fid)
if (bcoords >= 0).all() and (bcoords <= 1).all() and err < tol:
return bcoords, self.ravel_index(0, fid) # Z and time indices are 0 for now
return bcoords, fid

Check warning on line 46 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L46

Added line #L46 was not covered by tests
return None, None

def find_vertical_index() -> int:
if search2D:
return 0

Check warning on line 51 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L51

Added line #L51 was not covered by tests
else:
nz = self.z.shape[0]
if nz == 1:
return 0

Check warning on line 55 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L55

Added line #L55 was not covered by tests
zf = self.z.values
# Return zi such that zf[zi] <= z < zf[zi+1]
zi = np.searchsorted(zf, z, side="right") - 1 # Search assumes that z is positive and increasing with i
if zi < 0 or zi >= nz - 1:
raise FieldOutOfBoundError(z, y, x)

Check warning on line 60 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L60

Added line #L60 was not covered by tests
return zi

zi = find_vertical_index() # Find the vertical cell center nearest to z

if ei is not None:
zi, fi = self.unravel_index(ei)
bcoords, ei_new = try_face(fi)
_, fi = self.unravel_index(ei)
bcoords, fi_new = try_face(fi)

Check warning on line 67 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L66-L67

Added lines #L66 - L67 were not covered by tests
if bcoords is not None:
return bcoords, ei_new

return bcoords, self.ravel_index(zi, fi_new)

Check warning on line 69 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L69

Added line #L69 was not covered by tests
# Try neighbors of current face
for neighbor in self.uxgrid.face_face_connectivity[fi, :]:
if neighbor == -1:
continue
bcoords, ei_new = try_face(neighbor)
bcoords, fi_new = try_face(neighbor)

Check warning on line 74 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L74

Added line #L74 was not covered by tests
if bcoords is not None:
return bcoords, ei_new
return bcoords, self.ravel_index(zi, fi_new)

Check warning on line 76 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L76

Added line #L76 was not covered by tests

# Global fallback using spatial hash
fi, bcoords = self.uxgrid.get_spatial_hash().query([[x, y]])
if fi == -1:
raise FieldOutOfBoundError(z, y, x)

return bcoords, self.ravel_index(zi, fi)
return bcoords[0], self.ravel_index(zi, fi[0])

def _get_barycentric_coordinates(self, y, x, fi):
"""Checks if a point is inside a given face id on a UxGrid."""
Expand Down
39 changes: 37 additions & 2 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uxarray as ux
import xarray as xr

from parcels import Field, xgcm
from parcels import Field, UXPiecewiseConstantFace, UXPiecewiseLinearNode, xgcm
from parcels._datasets.structured.generic import T as T_structured
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
Expand Down Expand Up @@ -33,7 +33,10 @@ def test_field_init_param_types():
pytest.param(ux.UxDataArray(), XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])), id="uxdata-grid"),
pytest.param(
xr.DataArray(),
UxGrid(datasets_unstructured["stommel_gyre_delaunay"].uxgrid),
UxGrid(
datasets_unstructured["stommel_gyre_delaunay"].uxgrid,
z=datasets_unstructured["stommel_gyre_delaunay"].coords["nz"],
),
id="xarray-uxgrid",
),
],
Expand Down Expand Up @@ -110,6 +113,38 @@ def test_vectorfield_init_different_time_intervals():
...


def test_field_unstructured_z_linear():
"""Tests correctness of piecewise constant and piecewise linear interpolation methods on an unstructured grid with a vertical coordinate.
The example dataset is a FESOM2 square Delaunay grid with uniform z-coordinate. Cell centered and layer registered data are defined to be
linear functions of the vertical coordinate. This allows for testing of exactness of the interpolation methods.
"""
ds = datasets_unstructured["fesom2_square_delaunay_uniform_z_coordinate"].copy(deep=True)

# Change the pressure values to be linearly dependent on the vertical coordinate
for k, z in enumerate(ds.coords["nz1"]):
ds["p"].values[:, k, :] = z

# Change the vertical velocity values to be linearly dependent on the vertical coordinate
for k, z in enumerate(ds.coords["nz"]):
ds["W"].values[:, k, :] = z

grid = UxGrid(ds.uxgrid, z=ds.coords["nz"])
# Note that the vertical coordinate is required to be the position of the layer interfaces ("nz"), not the mid-layers ("nz1")
P = Field(name="p", data=ds.p, grid=grid, interp_method=UXPiecewiseConstantFace)

# Test above first cell center - for piecewise constant, should return the depth of the first cell center
assert np.isclose(P.eval(time=ds.time[0].values, z=10.0, y=30.0, x=30.0, applyConversion=False), 55.555557)
# Test below first cell center, but in the first layer - for piecewise constant, should return the depth of the first cell center
assert np.isclose(P.eval(time=ds.time[0].values, z=65.0, y=30.0, x=30.0, applyConversion=False), 55.555557)
# Test bottom layer - for piecewise constant, should return the depth of the of the bottom layer cell center
assert np.isclose(P.eval(time=ds.time[0].values, z=900.0, y=30.0, x=30.0, applyConversion=False), 944.44445801)

W = Field(name="W", data=ds.W, grid=grid, interp_method=UXPiecewiseLinearNode)
assert np.isclose(W.eval(time=ds.time[0].values, z=10.0, y=30.0, x=30.0, applyConversion=False), 10.0)
assert np.isclose(W.eval(time=ds.time[0].values, z=65.0, y=30.0, x=30.0, applyConversion=False), 65.0)
assert np.isclose(W.eval(time=ds.time[0].values, z=900.0, y=30.0, x=30.0, applyConversion=False), 900.0)


def test_field_unstructured_grid_creation(): ...


Expand Down
26 changes: 14 additions & 12 deletions tests/v4/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def uv_fesom_channel(ds_fesom_channel) -> VectorField:
U=Field(
name="U",
data=ds_fesom_channel.U,
grid=UxGrid(ds_fesom_channel.uxgrid),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
interp_method=UXPiecewiseConstantFace,
),
V=Field(
name="V",
data=ds_fesom_channel.V,
grid=UxGrid(ds_fesom_channel.uxgrid),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
interp_method=UXPiecewiseConstantFace,
),
)
Expand All @@ -57,17 +57,20 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
U=Field(
name="U",
data=ds_fesom_channel.U,
grid=UxGrid(ds_fesom_channel.uxgrid),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
interp_method=UXPiecewiseConstantFace,
),
V=Field(
name="V",
data=ds_fesom_channel.V,
grid=UxGrid(ds_fesom_channel.uxgrid),
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
interp_method=UXPiecewiseConstantFace,
),
W=Field(
name="W", data=ds_fesom_channel.W, grid=UxGrid(ds_fesom_channel.uxgrid), interp_method=UXPiecewiseLinearNode
name="W",
data=ds_fesom_channel.W,
grid=UxGrid(ds_fesom_channel.uxgrid, z=ds_fesom_channel.coords["nz"]),
interp_method=UXPiecewiseLinearNode,
),
)
return UVW
Expand Down Expand Up @@ -114,7 +117,6 @@ def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel):
pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))


@pytest.mark.xfail(reason="https://github.com/OceanParcels/Parcels/pull/2026#issuecomment-2945609874") # TODO: Fix
def test_fesom2_square_delaunay_uniform_z_coordinate_eval():
"""
Test the evaluation of a fieldset with a FESOM2 square Delaunay grid and uniform z-coordinate.
Expand All @@ -124,14 +126,14 @@ def test_fesom2_square_delaunay_uniform_z_coordinate_eval():
ds = datasets_unstructured["fesom2_square_delaunay_uniform_z_coordinate"]
UVW = VectorField(
name="UVW",
U=Field(name="U", data=ds.U, grid=UxGrid(ds.uxgrid), interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds.V, grid=UxGrid(ds.uxgrid), interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds.W, grid=UxGrid(ds.uxgrid), interp_method=UXPiecewiseLinearNode),
U=Field(name="U", data=ds.U, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
V=Field(name="V", data=ds.V, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace),
W=Field(name="W", data=ds.W, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseLinearNode),
)
P = Field(name="p", data=ds.p, grid=UxGrid(ds.uxgrid), interp_method=UXPiecewiseConstantFace)
P = Field(name="p", data=ds.p, grid=UxGrid(ds.uxgrid, z=ds.coords["nz"]), interp_method=UXPiecewiseConstantFace)
fieldset = FieldSet([UVW, P, UVW.U, UVW.V, UVW.W])

assert fieldset.U.eval(time=ds.time[0].values, z=1.0, y=30.0, x=30.0, applyConversion=False) == 1.0
assert fieldset.V.eval(time=ds.time[0].values, z=1.0, y=30.0, x=30.0, applyConversion=False) == 1.0
# assert fieldset.W.eval(time=ds.time[0].values, z=1.0, y=30.0, x=30.0, applyConversion=False) == 0.0
assert fieldset.P.eval(time=ds.time[0].values, z=1.0, y=30.0, x=30.0, applyConversion=False) == 1.0
assert fieldset.W.eval(time=ds.time[0].values, z=1.0, y=30.0, x=30.0, applyConversion=False) == 0.0
assert fieldset.p.eval(time=ds.time[0].values, z=1.0, y=30.0, x=30.0, applyConversion=False) == 1.0
Loading