-
Notifications
You must be signed in to change notification settings - Fork 41
add zonal interpolation fill for NaNs #1033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -145,6 +145,39 @@ def _load_all_variables( | |
| return ds[variables].compute() | ||
|
|
||
|
|
||
def _zonal_interp_periodic(array: np.ndarray, pad_width: int = 3) -> np.ndarray:
    """Fill NaNs via periodic linear interpolation along the last axis.

    The last axis is treated as a closed ring (e.g. longitude): a NaN gap that
    straddles the end/start of the axis is interpolated between the last valid
    value before the seam and the first valid value after it. Rows (1-D slices
    along the last axis) that contain no NaNs, or that are entirely NaN, are
    left unchanged.

    The function is eager: it operates on an in-memory NumPy array and returns
    a new array, it never modifies ``array`` in place.

    Args:
        array: Array of any shape; interpolation is along the last axis.
        pad_width: Retained for backward compatibility and ignored. Earlier
            versions wrapped this many elements from each end to emulate
            periodicity, which gave wrong results when a wrap-around gap was
            wider than ``pad_width`` or when ``pad_width`` exceeded the length
            of the last axis.

    Returns:
        Copy of array (same shape and dtype) with NaNs filled where possible.
    """
    # NOTE(review): a reviewer pointed out that this helper is greedy -- it
    # needs the data already loaded -- and asked where it sits in the call
    # stack relative to tensor loading; confirm against the callers.
    filled = np.array(array)  # np.array copies by default; input is untouched
    if filled.ndim == 0 or filled.size == 0:
        # Nothing to interpolate along (scalar input or an empty axis).
        return filled

    n_lon = filled.shape[-1]
    # Collapse all leading axes so each row is one ring to interpolate.
    rows = filled.reshape(-1, n_lon)
    nan_mask = np.isnan(rows)
    # Only rows with some, but not all, NaNs can (and need to) be filled.
    fillable = nan_mask.any(axis=-1) & ~nan_mask.all(axis=-1)

    lon_index = np.arange(n_lon)
    for i in np.flatnonzero(fillable):
        gaps = nan_mask[i]
        valid = ~gaps
        # ``period`` makes np.interp wrap the sample points itself, so gaps
        # of any width across the seam are interpolated periodically.
        rows[i, gaps] = np.interp(
            lon_index[gaps], lon_index[valid], rows[i, valid], period=n_lon
        )

    return rows.reshape(filled.shape)
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class FillNaNsConfig: | ||
| """ | ||
|
|
@@ -153,10 +186,14 @@ class FillNaNsConfig: | |
| Parameters: | ||
| method: Type of fill operation. Currently only 'constant' is supported. | ||
| value: Value to fill NaNs with. | ||
| zonal_interp_variables: Variables to fill via periodic zonal (longitude) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If this is getting merged, it seems like we'd want to add 'zonal_interp' as a method in addition to 'constant' rather than adding a new arg. |
||
| interpolation before applying the constant fill. Only valid for | ||
| latlon spatial dimensions. | ||
| """ | ||
|
|
||
| method: Literal["constant"] = "constant" | ||
| value: float = 0.0 | ||
| zonal_interp_variables: list[str] = dataclasses.field(default_factory=list) | ||
|
|
||
|
|
||
| def load_series_data_zarr_async( | ||
|
|
@@ -173,6 +210,9 @@ def load_series_data_zarr_async( | |
| selection = (time_slice, *nontime_selection) | ||
| loaded = _load_all_variables_zarr_async(path, names, selection) | ||
| if fill_nans is not None: | ||
| for k in fill_nans.zonal_interp_variables: | ||
| if k in loaded: | ||
| loaded[k] = _zonal_interp_periodic(loaded[k]) | ||
| for k, v in loaded.items(): | ||
| loaded[k] = np.nan_to_num(v, nan=fill_nans.value) | ||
| arrays = {} | ||
|
|
@@ -195,6 +235,11 @@ def load_series_data( | |
| # Fill NaNs after subsetting time slice to avoid triggering loading all | ||
| # data, since we do not use dask. | ||
| if fill_nans is not None: | ||
| for k in fill_nans.zonal_interp_variables: | ||
| if k in loaded: | ||
| loaded[k] = loaded[k].copy( | ||
| data=_zonal_interp_periodic(loaded[k].values) | ||
| ) | ||
| loaded = loaded.fillna(fill_nans.value) | ||
| arrays = {} | ||
| for n in names: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: use a different name than "orig" here, since it has padding on it.