Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fme/ace/inference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):

stepper = config.load_stepper()
stepper.set_eval()
stepper.backfill_deptho(data.dataset_info.vertical_coordinate)

if not config.allow_incompatible_dataset:
try:
Expand Down
1 change: 1 addition & 0 deletions fme/ace/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def run_inference_from_config(config: InferenceConfig):
ocean_fraction_name=stepper.ocean_fraction_name,
label_override=config.labels,
)
stepper.backfill_deptho(data.dataset_info.vertical_coordinate)

if not config.allow_incompatible_dataset:
try:
Expand Down
41 changes: 41 additions & 0 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,47 @@ def config(self) -> StepperConfig:
def derive_func(self) -> Callable[[TensorMapping, TensorMapping], TensorDict]:
return self._derive_func

def update_vertical_coordinate(
self, vertical_coordinate: VerticalCoordinate
) -> None:
"""Replace the vertical coordinate and rebuild ``derive_func``.

This allows overriding the checkpoint-serialized vertical coordinate
(e.g. to pick up a newly-added ``deptho`` field from the dataset).
"""
self._dataset_info = self._dataset_info.update_vertical_coordinate(
vertical_coordinate
)
try:
self._derive_func = vertical_coordinate.build_derive_function(
self._dataset_info.timestep,
self._dataset_info.horizontal_coordinates,
)
except MissingDatasetInfo:
self._derive_func = vertical_coordinate.build_derive_function(
self._dataset_info.timestep
)

def backfill_deptho(self, dataset_vertical_coordinate: VerticalCoordinate) -> None:
"""Adopt ``deptho`` from the dataset if the checkpoint lacks it.

Delegates to :meth:`VerticalCoordinate.adopt_deptho` so that
only coordinate types that support ``deptho`` need to know
about it. If the checkpoint coordinate is unchanged, this is
a no-op.
"""
try:
ckpt_vc = self.training_dataset_info.vertical_coordinate
except MissingDatasetInfo:
return
updated = ckpt_vc.adopt_deptho(dataset_vertical_coordinate)
if updated is not ckpt_vc:
logging.info(
"Backfilling deptho from dataset into checkpoint's "
"ocean vertical coordinate"
)
self.update_vertical_coordinate(updated)

@property
def surface_temperature_name(self) -> str | None:
return self._step_obj.surface_temperature_name
Expand Down
64 changes: 64 additions & 0 deletions fme/ace/stepper/test_single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,6 +1854,70 @@ def test_get_serialized_stepper_vertical_coordinate():
assert isinstance(vertical_coordinate, VerticalCoordinate)


class TestBackfillDeptho:
def test_backfills_when_checkpoint_lacks_deptho(self):
stepper = _get_stepper(["a"], ["a"])
idepth = torch.arange(4, dtype=torch.float)
mask = torch.ones(3)
ckpt_vc = DepthCoordinate(idepth=idepth, mask=mask)
stepper.update_vertical_coordinate(ckpt_vc)
vc = stepper.training_dataset_info.vertical_coordinate
assert isinstance(vc, DepthCoordinate)
assert vc.deptho is None

deptho = torch.tensor(2.5)
dataset_vc = DepthCoordinate(idepth=idepth, mask=mask, deptho=deptho)
stepper.backfill_deptho(dataset_vc)

updated_vc = stepper.training_dataset_info.vertical_coordinate
assert isinstance(updated_vc, DepthCoordinate)
assert updated_vc.deptho is not None
torch.testing.assert_close(updated_vc.deptho, deptho)

def test_noop_when_checkpoint_already_has_deptho(self):
stepper = _get_stepper(["a"], ["a"])
idepth = torch.arange(4, dtype=torch.float)
mask = torch.ones(3)
deptho = torch.tensor(2.5)
ckpt_vc = DepthCoordinate(idepth=idepth, mask=mask, deptho=deptho)
stepper.update_vertical_coordinate(ckpt_vc)

dataset_vc = DepthCoordinate(
idepth=idepth, mask=mask, deptho=torch.tensor(999.0)
)
stepper.backfill_deptho(dataset_vc)

updated_vc = stepper.training_dataset_info.vertical_coordinate
assert isinstance(updated_vc, DepthCoordinate)
torch.testing.assert_close(updated_vc.deptho, deptho)

def test_noop_when_dataset_lacks_deptho(self):
stepper = _get_stepper(["a"], ["a"])
idepth = torch.arange(4, dtype=torch.float)
mask = torch.ones(3)
ckpt_vc = DepthCoordinate(idepth=idepth, mask=mask)
stepper.update_vertical_coordinate(ckpt_vc)

dataset_vc = DepthCoordinate(idepth=idepth, mask=mask)
stepper.backfill_deptho(dataset_vc)

vc = stepper.training_dataset_info.vertical_coordinate
assert isinstance(vc, DepthCoordinate)
assert vc.deptho is None

def test_noop_for_non_depth_coordinate(self):
stepper = _get_stepper(["a"], ["a"])
original_vc = stepper.training_dataset_info.vertical_coordinate
assert isinstance(original_vc, HybridSigmaPressureCoordinate)

other_vc = HybridSigmaPressureCoordinate(
ak=torch.arange(4, dtype=torch.float),
bk=torch.arange(4, dtype=torch.float),
)
stepper.backfill_deptho(other_vc)
assert stepper.training_dataset_info.vertical_coordinate == original_vc


def _get_stepper_with_input_masking(dataset_info_has_mask_provider: bool = True):
# basic StepperConfig with input_masking configured
config = StepperConfig(
Expand Down
25 changes: 25 additions & 0 deletions fme/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ def build_derive_function(
) -> DeriveFnABC:
pass

def adopt_deptho(self, other: "VerticalCoordinate") -> "VerticalCoordinate":
"""Return a coordinate updated with ``deptho`` from *other*, if applicable.

The default returns *self* unchanged. Subclasses that support
``deptho`` (e.g. :class:`DepthCoordinate`) override this.
"""
return self

@property
@abc.abstractmethod
def coords(self) -> dict[str, np.ndarray]:
Expand Down Expand Up @@ -368,6 +376,23 @@ def device(self) -> str:
def coords(self) -> dict[str, np.ndarray]:
return {"idepth": self.idepth.cpu().numpy()}

def with_deptho(self, deptho: torch.Tensor) -> "DepthCoordinate":
"""Return a new DepthCoordinate with the given ``deptho`` attached."""
return DepthCoordinate(
idepth=self.idepth,
mask=self.mask,
deptho=deptho,
)

def adopt_deptho(self, other: "VerticalCoordinate") -> "DepthCoordinate":
if (
self.deptho is None
and isinstance(other, DepthCoordinate)
and other.deptho is not None
):
return self.with_deptho(other.deptho)
return self

def to(self, device: str) -> "DepthCoordinate":
return DepthCoordinate(
idepth=self.idepth.to(device),
Expand Down
17 changes: 17 additions & 0 deletions fme/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,23 @@ def update_variable_metadata(
all_labels=self._all_labels,
)

def update_vertical_coordinate(
self, vertical_coordinate: VerticalCoordinate
) -> "DatasetInfo":
"""
Return a new DatasetInfo with the vertical coordinate replaced.
"""
return DatasetInfo(
horizontal_coordinates=self._horizontal_coordinates,
vertical_coordinate=vertical_coordinate,
mask_provider=self._mask_provider,
timestep=self._timestep,
variable_metadata=self._variable_metadata,
gridded_operations=self._gridded_operations,
img_shape=self._img_shape,
all_labels=self._all_labels,
)

def get_state(self) -> dict[str, Any]:
if self._gridded_operations is not None:
gridded_operations = self._gridded_operations.get_state()
Expand Down
50 changes: 50 additions & 0 deletions fme/core/test_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,56 @@ def test_depth_integral_gradient_with_mask():
assert torch.all(torch.isfinite(integrand.grad))


def test_with_deptho():
idepth = torch.tensor([0.0, 10.0, 50.0])
mask = torch.ones(2, 3, 2)
coord = DepthCoordinate(idepth=idepth, mask=mask)
assert coord.deptho is None

deptho = torch.full((2, 3), 40.0)
updated = coord.with_deptho(deptho)
assert updated.deptho is not None
torch.testing.assert_close(updated.deptho, deptho)
torch.testing.assert_close(updated.idepth, idepth)
torch.testing.assert_close(updated.mask, mask)
# original unchanged
assert coord.deptho is None


class TestAdoptDeptho:
def test_depth_coordinate_adopts_when_missing(self):
idepth = torch.tensor([0.0, 10.0, 50.0])
mask = torch.ones(2, 3, 2)
coord = DepthCoordinate(idepth=idepth, mask=mask)
deptho = torch.full((2, 3), 40.0)
other = DepthCoordinate(idepth=idepth, mask=mask, deptho=deptho)

updated = coord.adopt_deptho(other)
assert updated is not coord
assert updated.deptho is not None
torch.testing.assert_close(updated.deptho, deptho)

def test_depth_coordinate_noop_when_already_present(self):
idepth = torch.tensor([0.0, 10.0, 50.0])
mask = torch.ones(2, 3, 2)
existing_deptho = torch.full((2, 3), 30.0)
coord = DepthCoordinate(idepth=idepth, mask=mask, deptho=existing_deptho)
other_deptho = torch.full((2, 3), 99.0)
other = DepthCoordinate(idepth=idepth, mask=mask, deptho=other_deptho)

result = coord.adopt_deptho(other)
assert result is coord

def test_depth_coordinate_noop_when_other_lacks_deptho(self):
idepth = torch.tensor([0.0, 10.0, 50.0])
mask = torch.ones(2, 3, 2)
coord = DepthCoordinate(idepth=idepth, mask=mask)
other = DepthCoordinate(idepth=idepth, mask=mask)

result = coord.adopt_deptho(other)
assert result is coord


@pytest.mark.skipif(e2ghpx is None, reason="earth2grid is not available")
@pytest.mark.parametrize("pad", [True, False])
def test_healpix_coordinates_xyz(pad: bool):
Expand Down
1 change: 1 addition & 0 deletions fme/coupled/inference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def run_evaluator_from_config(config: InferenceEvaluatorConfig):
dataset_info=stepper.training_dataset_info,
)
stepper.set_eval()
stepper.ocean.backfill_deptho(data.ocean_properties.vertical_coordinate)

aggregator_config: InferenceEvaluatorAggregatorConfig = config.aggregator
batch = next(iter(data.loader))
Expand Down
1 change: 1 addition & 0 deletions fme/coupled/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def run_inference_from_config(config: InferenceConfig):
initial_condition=initial_condition,
dataset_info=stepper.training_dataset_info,
)
stepper.ocean.backfill_deptho(data.ocean_properties.vertical_coordinate)

aggregator_config: InferenceAggregatorConfig = config.aggregator
variable_metadata = get_derived_variable_metadata() | data.variable_metadata
Expand Down
Loading