Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
62077b9
claude first attempt to add per variable loss to TrainAgg
Arcomano1234 Mar 16, 2026
1d1aded
make per channel loss configurable and clean up code
Arcomano1234 Mar 17, 2026
7079529
test era5 config
Arcomano1234 Mar 17, 2026
ef1b9f6
set per channel loss in TrainAgg to False
Arcomano1234 Mar 17, 2026
87af429
update tests
Arcomano1234 Mar 17, 2026
6223e0e
incorporate comments
Arcomano1234 Mar 18, 2026
35f2f1d
fix doc
Arcomano1234 Mar 18, 2026
84211b1
remove coupling of train agg and generic trainer
Arcomano1234 Mar 18, 2026
7def3ca
Merge branch 'main' into feature/per-channel-loss-train-agg
Arcomano1234 Mar 18, 2026
a67c993
revert changes to ERA5 configs
Arcomano1234 Mar 18, 2026
377fdc0
Merge branch 'feature/per-channel-loss-train-agg' of github.com:ai2cm…
Arcomano1234 Mar 18, 2026
61811a0
claude attempt for breaking things up to PRs
Arcomano1234 Mar 20, 2026
b310017
claude attempt for breaking things up to PRs
Arcomano1234 Mar 20, 2026
022dc70
claude losses return 1d vector over channel dim
Arcomano1234 Mar 23, 2026
c608a9b
claude refactor to no reduce losses
Arcomano1234 Mar 24, 2026
3ad2b01
Merge branch 'main' into feature/losses-return-1d-loss-vector
Arcomano1234 Mar 26, 2026
1f75eba
make loss reduction an argument
Arcomano1234 Mar 30, 2026
373d6ec
move reduction arg to StepLoss forward
Arcomano1234 Mar 30, 2026
f64dc68
address naming comment
Arcomano1234 Mar 30, 2026
d236f93
Merge branch 'main' into feature/per-channel-loss-train-agg
Arcomano1234 Mar 30, 2026
a78a4a8
Merge branch 'feature/losses-return-1d-loss-vector' into feature/per-…
Arcomano1234 Mar 30, 2026
72c0e19
Add newly created regression files
Arcomano1234 Mar 30, 2026
f82f599
clean up loss and tests to remove unused code after the loss-return-1…
Arcomano1234 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions fme/ace/aggregator/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,93 @@ def test_aggregator_gets_logs_with_no_batches(config: TrainAggregatorConfig):
logs = agg.get_logs(label="test")
assert np.isnan(logs.pop("test/mean/loss"))
assert logs == {}


def test_aggregator_logs_per_channel_loss():
    """
    Per-channel (per-variable) loss is accumulated from batch.metrics and reported.
    """
    batch_size, n_ensemble, n_time = 4, 1, 2
    nx, ny = 2, 2
    device = get_device()
    operations = LatLonOperations(
        area_weights=torch.ones(nx, ny, device=device)
    )
    agg = TrainAggregator(
        config=TrainAggregatorConfig(
            spherical_power_spectrum=False, weighted_rmse=False, per_channel_loss=True
        ),
        operations=operations,
    )
    target_data = EnsembleTensorDict(
        {"a": torch.randn(batch_size, 1, n_time, nx, ny, device=device)},
    )
    gen_data = EnsembleTensorDict(
        {"a": torch.randn(batch_size, n_ensemble, n_time, nx, ny, device=device)},
    )
    # Record two batches with different totals; the aggregator should report
    # the mean over batches for both the total and the per-variable loss.
    for total_loss, channel_loss in [(1.0, 0.5), (2.0, 1.0)]:
        agg.record_batch(
            batch=TrainOutput(
                metrics={
                    "loss": torch.tensor(total_loss, device=device),
                    "loss/a": torch.tensor(channel_loss, device=device),
                },
                target_data=target_data,
                gen_data=gen_data,
                time=xr.DataArray(
                    np.zeros((batch_size, n_time)), dims=["sample", "time"]
                ),
                normalize=lambda x: x,
            ),
        )
    logs = agg.get_logs(label="train")
    assert logs["train/mean/loss"] == 1.5
    assert logs["train/mean/loss/a"] == 0.75


def test_aggregator_per_channel_loss_disabled():
    """When per_channel_loss=False, get_logs does not include per-variable loss."""
    batch_size, n_ensemble, n_time = 4, 1, 2
    nx, ny = 2, 2
    device = get_device()
    operations = LatLonOperations(
        area_weights=torch.ones(nx, ny, device=device)
    )
    agg = TrainAggregator(
        config=TrainAggregatorConfig(
            spherical_power_spectrum=False,
            weighted_rmse=False,
            per_channel_loss=False,
        ),
        operations=operations,
    )
    target_data = EnsembleTensorDict(
        {"a": torch.randn(batch_size, 1, n_time, nx, ny, device=device)},
    )
    gen_data = EnsembleTensorDict(
        {"a": torch.randn(batch_size, n_ensemble, n_time, nx, ny, device=device)},
    )
    # A per-variable metric is recorded, but with the feature disabled it
    # must not appear in the logs.
    batch = TrainOutput(
        metrics={
            "loss": torch.tensor(1.0, device=device),
            "loss/a": torch.tensor(0.5, device=device),
        },
        target_data=target_data,
        gen_data=gen_data,
        time=xr.DataArray(np.zeros((batch_size, n_time)), dims=["sample", "time"]),
        normalize=lambda x: x,
    )
    agg.record_batch(batch=batch)
    logs = agg.get_logs(label="train")
    assert logs["train/mean/loss"] == 1.0
    assert "train/mean/loss/a" not in logs
23 changes: 23 additions & 0 deletions fme/ace/aggregator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim
from fme.core.typing_ import TensorMapping

# Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]).
PER_CHANNEL_LOSS_PREFIX = "loss/"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good way. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper. This would also help reduce the need to record it when we aren't using it during training.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest using a new attribute on the TrainOutput instead of a string label.



@dataclasses.dataclass
class TrainAggregatorConfig:
Expand All @@ -23,10 +26,13 @@ class TrainAggregatorConfig:
Attributes:
spherical_power_spectrum: Whether to compute the spherical power spectrum.
weighted_rmse: Whether to compute the weighted RMSE.
per_channel_loss: Whether to accumulate and report per-variable (per-channel)
loss in get_logs (e.g. train/mean/loss/<var_name>).
"""

spherical_power_spectrum: bool = True
weighted_rmse: bool = True
per_channel_loss: bool = True


class Aggregator(Protocol):
Expand All @@ -48,6 +54,8 @@ class TrainAggregator(AggregatorABC[TrainOutput]):
def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations):
self._n_loss_batches = 0
self._loss = torch.tensor(0.0, device=get_device())
self._per_channel_loss: dict[str, torch.Tensor] = {}
self._per_channel_loss_enabled = config.per_channel_loss
self._paired_aggregators: dict[str, Aggregator] = {}
if config.spherical_power_spectrum:
try:
Expand All @@ -73,6 +81,16 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations)
def record_batch(self, batch: TrainOutput):
self._loss += batch.metrics["loss"]
self._n_loss_batches += 1
if self._per_channel_loss_enabled:
for key, value in batch.metrics.items():
if not key.startswith(PER_CHANNEL_LOSS_PREFIX):
continue
var_name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX)
acc = self._per_channel_loss.get(
var_name,
torch.tensor(0.0, device=get_device(), dtype=value.dtype),
)
self._per_channel_loss[var_name] = acc + value

folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data)
folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble)
Expand Down Expand Up @@ -100,6 +118,11 @@ def get_logs(self, label: str) -> dict[str, torch.Tensor]:
logs[f"{label}/mean/loss"] = float(
dist.reduce_mean(self._loss / self._n_loss_batches).cpu().numpy()
)
if self._n_loss_batches > 0 and self._per_channel_loss_enabled:
for var_name, acc in self._per_channel_loss.items():
logs[f"{label}/mean/loss/{var_name}"] = float(
dist.reduce_mean(acc / self._n_loss_batches).cpu().numpy()
)
return logs

@torch.no_grad()
Expand Down
19 changes: 16 additions & 3 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,7 @@ def _accumulate_loss(
"data requirements are retrieved, so this is a bug."
)
n_forward_steps = stochastic_n_forward_steps
per_channel_sum: dict[str, torch.Tensor] | None = None
for step in range(n_forward_steps):
optimize_step = (
step == n_forward_steps - 1 or not self._config.optimize_last_step_only
Expand All @@ -1626,10 +1627,22 @@ def _accumulate_loss(
for k, v in target_data.data.items()
}
)
step_loss = self._loss_obj(gen_step, target_step, step=step)
metrics[f"loss_step_{step}"] = step_loss.detach()
step_loss = self._loss_obj(
gen_step, target_step, step=step, reduce=False
)
step_total_loss = step_loss.sum()
metrics[f"loss_step_{step}"] = step_total_loss.detach()
per_ch = self._loss_obj.loss.packer.unpack(step_loss.detach(), axis=0)
if per_channel_sum is None:
per_channel_sum = {k: v.clone() for k, v in per_ch.items()}
else:
for k in per_channel_sum:
per_channel_sum[k] = per_channel_sum[k] + per_ch[k]
if optimize_step:
optimization.accumulate_loss(step_loss)
optimization.accumulate_loss(step_total_loss)
if per_channel_sum is not None:
for k, v in per_channel_sum.items():
metrics[f"loss/{k}"] = v.detach()
return output_list

def update_training_history(self, training_job: TrainingJob) -> None:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 4 additions & 1 deletion fme/core/generics/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,10 @@ def train_one_epoch(self):
stop_batch=self.config.train_evaluation_batches
):
with GlobalTimer():
stepped = self.stepper.train_on_batch(batch, self._no_optimization)
stepped = self.stepper.train_on_batch(
batch,
self._no_optimization,
)
aggregator.record_batch(stepped)
if (
self._should_save_checkpoints()
Expand Down
67 changes: 62 additions & 5 deletions fme/core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,42 @@
from fme.core.typing_ import TensorMapping


def _channel_dim_positive(ndim: int, channel_dim: int) -> int:
return channel_dim if channel_dim >= 0 else ndim + channel_dim


def _uniform_broadcast_per_channel(
total: torch.Tensor, n_channels: int
) -> torch.Tensor:
"""Evenly split a scalar across ``n_channels`` so ``.sum() == total``."""
return (total / n_channels).expand(n_channels)


def _reduce_to_per_channel(
loss_value: torch.Tensor,
channel_dim: int,
n_channels: int,
) -> torch.Tensor:
"""Reduce any loss tensor to shape ``(n_channels,)``.

Handles element-wise tensors, partially-reduced tensors
(e.g. after area-weighted mean removes spatial dims), and
scalars (uniformly broadcast).

``channel_dim`` must be a **non-negative** index, computed from
the *input* tensor (before the loss potentially removes dims).

The result satisfies ``tensor.sum() == loss_value.mean()``
for element-wise losses.
"""
if loss_value.ndim == 0:
return _uniform_broadcast_per_channel(loss_value, n_channels)
dims = tuple(i for i in range(loss_value.ndim) if i != channel_dim)
if not dims:
return loss_value / n_channels
return loss_value.mean(dim=dims) / n_channels


class NaNLoss(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -59,7 +95,16 @@ def __call__(
self,
predict_dict: TensorMapping,
target_dict: TensorMapping,
):
reduce: bool = True,
) -> torch.Tensor:
"""
Args:
predict_dict: The predicted data.
target_dict: The target data.
reduce: If True (default), return a scalar loss.
If False, return a per-channel loss vector of
shape ``(n_channels,)``.
"""
predict_tensors = self.packer.pack(
self.normalizer.normalize(predict_dict), axis=self.channel_dim
)
Expand All @@ -71,7 +116,14 @@ def __call__(
predict_tensors = torch.where(nan_mask, 0.0, predict_tensors)
target_tensors = torch.where(nan_mask, 0.0, target_tensors)

return self.loss(predict_tensors, target_tensors)
result = self.loss(predict_tensors, target_tensors)
if not reduce:
cdim = _channel_dim_positive(predict_tensors.ndim, self.channel_dim)
n_c = int(predict_tensors.shape[cdim])
return _reduce_to_per_channel(result, cdim, n_c)
if result.ndim > 0:
return result.mean()
return result

def get_normalizer_state(self) -> dict[str, float]:
return self.normalizer.get_state()
Expand Down Expand Up @@ -457,18 +509,23 @@ def forward(
predict_dict: TensorMapping,
target_dict: TensorMapping,
step: int,
reduce: bool = True,
) -> torch.Tensor:
"""
Args:
predict_dict: The predicted data.
target_dict: The target data.
step: The step number, indexed from 0 for the first step.
reduce: If True (default), return a scalar loss.
If False, return a per-channel loss vector of
shape ``(n_channels,)``.

Returns:
The loss.
The loss, scalar or ``(n_channels,)`` depending on
``reduce``.
"""
step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
return self.loss(predict_dict, target_dict) * step_weight
return self.loss(predict_dict, target_dict, reduce=reduce) * step_weight


@dataclasses.dataclass
Expand Down Expand Up @@ -525,7 +582,7 @@ def build(
channel_dim: int = -3,
) -> StepLoss:
loss = self.loss_config.build(
reduction="mean",
reduction="none",
gridded_operations=gridded_ops,
)
return StepLoss(
Expand Down
Loading
Loading