Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions fme/ace/data_loading/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,15 @@ def _get_ocean_mask(ocean_fraction: torch.Tensor, cutoff: float = 0.5) -> torch.
class ConstantConfig(PerturbationConfig):
"""
Configuration for a constant perturbation.

Parameters:
amplitude: The amplitude of the perturbation.
ocean_fraction_cutoff: Minimum ocean fraction to apply the perturbation.
If None, the perturbation is applied to all grid points.
"""

amplitude: float = 1.0
ocean_fraction_cutoff: float | None = 0.5

def apply_perturbation(
self,
Expand All @@ -102,8 +108,11 @@ def apply_perturbation(
lon: torch.Tensor,
ocean_fraction: torch.Tensor,
):
ocean_mask = _get_ocean_mask(ocean_fraction)
data[ocean_mask] += self.amplitude # type: ignore
if self.ocean_fraction_cutoff is None:
data += self.amplitude
else:
ocean_mask = _get_ocean_mask(ocean_fraction, self.ocean_fraction_cutoff)
data[ocean_mask] += self.amplitude # type: ignore


@PerturbationSelector.register("greens_function")
Expand All @@ -120,13 +129,16 @@ class GreensFunctionConfig(PerturbationConfig):
lon_center: The longitude at the center of the patch in degrees.
lat_width: latitudinal width of the patch in degrees.
lon_width: longitudinal width of the patch in degrees.
ocean_fraction_cutoff: Minimum ocean fraction to apply the perturbation.
If None, the perturbation is applied to all grid points.
"""

amplitude: float = 1.0
lat_center: float = 0.0
lon_center: float = 0.0
lat_width: float = 10.0
lon_width: float = 10.0
ocean_fraction_cutoff: float | None = 0.5

def __post_init__(self):
self._lat_center_rad = np.deg2rad(self.lat_center)
Expand Down Expand Up @@ -166,7 +178,6 @@ def apply_perturbation(
lat_in_patch = torch.abs(lat - self.lat_center) < self.lat_width / 2.0
lon_in_patch, lon_shifted = self._wrap_longitude_discontinuity(lon)
mask = lat_in_patch & lon_in_patch
ocean_mask = _get_ocean_mask(ocean_fraction)
perturbation = self.amplitude * (
torch.cos(
torch.pi
Expand All @@ -185,4 +196,7 @@ def apply_perturbation(
)
mask = mask.expand(data.shape)
perturbation = perturbation.expand(data.shape)
data[mask & ocean_mask] += perturbation[mask & ocean_mask]
if self.ocean_fraction_cutoff is not None:
ocean_mask = _get_ocean_mask(ocean_fraction, self.ocean_fraction_cutoff)
mask = mask & ocean_mask
data[mask] += perturbation[mask]
34 changes: 34 additions & 0 deletions fme/ace/data_loading/test_perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,40 @@ def test_constant_perturbation_config():
torch.testing.assert_close(data, expected)


def test_constant_perturbation_all_grid_points():
    """With ocean_fraction_cutoff=None, perturbation applies everywhere."""
    built = PerturbationSelector(
        type="constant",
        config={"amplitude": 2.0, "ocean_fraction_cutoff": None},
    ).build()
    assert isinstance(built, ConstantConfig)
    device = fme.get_device()
    n_lat, n_lon = 5, 5
    lat_grid, lon_grid = torch.meshgrid(
        torch.arange(n_lat, device=device),
        torch.arange(n_lon, device=device),
        indexing="ij",
    )
    # Every point is land, yet the perturbation must still apply because the
    # ocean-fraction cutoff is disabled.
    ocean_fraction = torch.zeros(n_lat, n_lon, device=device)
    field = torch.ones(n_lat, n_lon, device=device)
    built.apply_perturbation(field, lat_grid, lon_grid, ocean_fraction)
    torch.testing.assert_close(
        field, torch.full((n_lat, n_lon), 3.0, device=device)
    )


def test_constant_perturbation_ocean_only():
    """Default ocean_fraction_cutoff=0.5 skips land points."""
    cfg = ConstantConfig(amplitude=1.0)
    device = fme.get_device()
    n_lat, n_lon = 4, 4
    lat_grid, lon_grid = torch.meshgrid(
        torch.arange(n_lat, device=device).float(),
        torch.arange(n_lon, device=device).float(),
        indexing="ij",
    )
    ocean_fraction = torch.zeros(n_lat, n_lon, device=device)
    ocean_fraction[:2, :] = 1.0  # top half is ocean
    field = torch.zeros(n_lat, n_lon, device=device)
    cfg.apply_perturbation(field, lat_grid, lon_grid, ocean_fraction)
    # Ocean rows receive the amplitude; land rows stay untouched.
    assert torch.all(field[:2, :] == 1.0)
    assert torch.all(field[2:, :] == 0.0)


def test_green_function_perturbation_config():
selector = PerturbationSelector(
type="greens_function",
Expand Down
44 changes: 44 additions & 0 deletions fme/core/dataset/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .utils import (
_broadcast_array_to_tensor,
_get_indexers,
_zonal_interp_periodic,
accumulate_labels,
as_broadcasted_tensor,
decode_timestep,
Expand Down Expand Up @@ -204,3 +205,46 @@ def test__broadcast_array_to_tensor_raises_assertion_error():
arr = np.zeros(3)
with pytest.raises(ValueError, match="matching time dimension"):
_broadcast_array_to_tensor(arr, (TIME_DIM, LAT_DIM, LON_DIM), (4, 2, 3))


class TestZonalInterpPeriodic:
    def test_no_nans_unchanged(self):
        """An input with no NaNs comes back with identical values."""
        clean = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]])
        np.testing.assert_array_equal(_zonal_interp_periodic(clean), clean)

    def test_all_nan_row_unchanged(self):
        """A row with no valid points has nothing to interpolate from."""
        all_nan = np.full((1, 6), np.nan)
        assert np.isnan(_zonal_interp_periodic(all_nan)).all()

    def test_interior_nan_interpolated(self):
        """A single interior gap is filled with its linear midpoint."""
        gappy = np.array([[1.0, np.nan, 3.0, 4.0, 5.0, 6.0]])
        filled = _zonal_interp_periodic(gappy)
        assert not np.isnan(filled).any()
        np.testing.assert_allclose(filled[0, 1], 2.0)

    def test_periodic_boundary(self):
        """NaNs at the edges should be filled using values wrapped from the
        opposite end of the array."""
        edges = np.array([[np.nan, 2.0, 3.0, 4.0, 5.0, np.nan]])
        filled = _zonal_interp_periodic(edges)
        assert not np.isnan(filled).any()
        # Padded array is [4, 5, NaN, | NaN, 2, 3, 4, 5, NaN, | NaN, 2, 3].
        # Index 0 (padded idx 3) is interpolated between padded (1,5) and
        # (4,2) giving 3.0; index 5 (padded idx 8) similarly gives 4.0.
        np.testing.assert_allclose(filled[0, 0], 3.0)
        np.testing.assert_allclose(filled[0, -1], 4.0)

    def test_3d_array(self):
        """Interpolation is applied independently along the last axis of
        higher-rank input."""
        cube = np.ones((2, 3, 8))
        cube[:, :, 2] = np.nan
        filled = _zonal_interp_periodic(cube)
        assert not np.isnan(filled).any()
        np.testing.assert_allclose(filled[:, :, 2], 1.0)

    def test_original_values_preserved(self):
        """Valid entries are passed through untouched; only NaNs change."""
        partial = np.array([[10.0, np.nan, np.nan, 40.0, 50.0, 60.0]])
        filled = _zonal_interp_periodic(partial)
        np.testing.assert_array_equal(filled[0, 3:], partial[0, 3:])
        np.testing.assert_array_equal(filled[0, 0], partial[0, 0])
38 changes: 38 additions & 0 deletions fme/core/dataset/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,44 @@ def test_fill_nans(mock_data_fixture, engine, file_pattern, request):
assert torch.all(data["constant_var"][:, 0, 0] == 0)


@pytest.mark.parametrize(
    "mock_data_fixture, engine, file_pattern",
    [
        ("mock_monthly_netcdfs_with_nans", "netcdf4", "*.nc"),
        ("mock_monthly_zarr_with_nans", "zarr", "*.zarr"),
    ],
)
def test_fill_nans_zonal_interp(mock_data_fixture, engine, file_pattern, request):
    """Variables listed in zonal_interp_variables are gap-filled by
    interpolation; everything else falls back to the constant fill."""
    mock_data: MockData = request.getfixturevalue(mock_data_fixture)
    config = XarrayDataConfig(
        data_path=mock_data.tmpdir,
        fill_nans=FillNaNsConfig(zonal_interp_variables=["foo"]),
        engine=engine,
        file_pattern=file_pattern,
    )
    dataset = xarray_dataset_constructor(
        config, mock_data.var_names.all_names, 2
    )
    data, _, _, _ = dataset[0]
    # "foo" NaNs at lon=0 should be interpolated, not filled with constant 0
    foo_lon0 = data["foo"][0, :, 0]
    assert not torch.any(torch.isnan(foo_lon0))
    assert not torch.all(foo_lon0 == 0)
    # "bar" also had NaNs but is not in zonal_interp_variables,
    # so it should be filled with constant 0
    assert torch.all(data["bar"][0, :, 0] == 0)
    # constant_var is not in zonal_interp_variables, filled with constant 0
    assert torch.all(data["constant_var"][:, 0, 0] == 0)


def test_fill_nans_zonal_interp_healpix_raises():
    """Zonal interpolation assumes a longitude axis, so requesting it
    together with healpix spatial dimensions must be rejected."""
    nan_config = FillNaNsConfig(zonal_interp_variables=["sst"])
    with pytest.raises(ValueError, match="zonal_interp_variables"):
        XarrayDataConfig(
            data_path="/unused",
            spatial_dimensions="healpix",
            fill_nans=nan_config,
        )


def test_keep_nans(mock_monthly_netcdfs_with_nans):
config_keep_nan = XarrayDataConfig(data_path=mock_monthly_netcdfs_with_nans.tmpdir)
names = mock_monthly_netcdfs_with_nans.var_names.all_names
Expand Down
45 changes: 45 additions & 0 deletions fme/core/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,39 @@ def _load_all_variables(
return ds[variables].compute()


def _zonal_interp_periodic(array: np.ndarray, pad_width: int = 3) -> np.ndarray:
"""Fill NaNs via periodic linear interpolation along the last axis.

Wraps the longitude dimension periodically, then uses linear interpolation
to fill NaN gaps. Rows that are entirely NaN are left unchanged.

Args:
array: Array of any shape; interpolation is along the last axis.
pad_width: Number of elements to mirror from each end for periodicity.

Returns:
Copy of array with NaNs filled where possible.
"""
n_lon = array.shape[-1]
left = array[..., -pad_width:]
right = array[..., :pad_width]
padded = np.concatenate([left, array, right], axis=-1)

orig_shape = padded.shape
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit use a different name than "orig" here since it has padding on it

flat = padded.reshape(-1, orig_shape[-1]).copy()

x = np.arange(flat.shape[-1])
for i in range(flat.shape[0]):
row = flat[i]
mask = np.isnan(row)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is greedy, where does this fall in the call stack relative to where we actually load the tensors? I guess not actually wrong to load data earlier, but a bit confusing in terms of the data model

if mask.any() and not mask.all():
valid = ~mask
flat[i, mask] = np.interp(x[mask], x[valid], row[valid])

result = flat.reshape(orig_shape)
return result[..., pad_width : pad_width + n_lon]


@dataclasses.dataclass
class FillNaNsConfig:
"""
Expand All @@ -153,10 +186,14 @@ class FillNaNsConfig:
Parameters:
method: Type of fill operation. Currently only 'constant' is supported.
value: Value to fill NaNs with.
zonal_interp_variables: Variables to fill via periodic zonal (longitude)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is getting merged it seems like we'd want to add 'zonal_interp` as a method in addition to 'constant' rather than adding a new arg

interpolation before applying the constant fill. Only valid for
latlon spatial dimensions.
"""

method: Literal["constant"] = "constant"
value: float = 0.0
zonal_interp_variables: list[str] = dataclasses.field(default_factory=list)


def load_series_data_zarr_async(
Expand All @@ -173,6 +210,9 @@ def load_series_data_zarr_async(
selection = (time_slice, *nontime_selection)
loaded = _load_all_variables_zarr_async(path, names, selection)
if fill_nans is not None:
for k in fill_nans.zonal_interp_variables:
if k in loaded:
loaded[k] = _zonal_interp_periodic(loaded[k])
for k, v in loaded.items():
loaded[k] = np.nan_to_num(v, nan=fill_nans.value)
arrays = {}
Expand All @@ -195,6 +235,11 @@ def load_series_data(
# Fill NaNs after subsetting time slice to avoid triggering loading all
# data, since we do not use dask.
if fill_nans is not None:
for k in fill_nans.zonal_interp_variables:
if k in loaded:
loaded[k] = loaded[k].copy(
data=_zonal_interp_periodic(loaded[k].values)
)
loaded = loaded.fillna(fill_nans.value)
arrays = {}
for n in names:
Expand Down
15 changes: 14 additions & 1 deletion fme/core/dataset/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.time import RepeatedInterval, TimeSlice
from fme.core.dataset.utils import FillNaNsConfig
from fme.core.dataset.utils import FillNaNsConfig, _zonal_interp_periodic
from fme.core.mask_provider import MaskProvider
from fme.core.stacker import Stacker
from fme.core.typing_ import Slice, TensorDict
Expand Down Expand Up @@ -507,6 +507,15 @@ def __post_init__(self):
)
self.torch_dtype # check it can be retrieved
self._default_file_pattern_check()
if (
self.fill_nans is not None
and self.fill_nans.zonal_interp_variables
and self.spatial_dimensions != "latlon"
):
raise ValueError(
"zonal_interp_variables can only be used with "
"spatial_dimensions='latlon'."
)

@property
def zarr_engine_used(self) -> bool:
Expand Down Expand Up @@ -933,6 +942,10 @@ def get_sample_by_time_slice(self, time_slice: slice) -> DatasetItem:
for name in self._time_invariant_names:
variable = ds[name].variable
if self.fill_nans is not None:
if name in self.fill_nans.zonal_interp_variables:
variable = variable.copy(
data=_zonal_interp_periodic(variable.values)
)
variable = variable.fillna(self.fill_nans.value)
tensors[name] = as_broadcasted_tensor(variable, self.dims, shape)
ds.close()
Expand Down
Loading