Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/04_gaussian/data/gen_mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np

# True parameters (no batch dimension)
-z_true = np.array([-2.0, 0.0, 2.0])
+z_true = np.array([-5.0, 0.0, 5.0])

# Observation: x = exp(z), no noise (Asimov)
x_obs = np.exp(z_true)
Expand Down
Binary file modified examples/04_gaussian/data/mock_data.npz
Binary file not shown.
20 changes: 11 additions & 9 deletions falcon/estimators/gaussian_fullcov.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def __init__(
# MLP for mean prediction
self.net = build_mlp(condition_dim, hidden_dim, param_dim, num_layers)

-        # Input statistics (conditions) - diagonal whitening
-        self.register_buffer("_input_mean", torch.zeros(condition_dim))
-        self.register_buffer("_input_std", torch.ones(condition_dim))
+        # Input statistics (conditions) - diagonal whitening; float64 for precision
+        self.register_buffer("_input_mean", torch.zeros(condition_dim, dtype=torch.float64))
+        self.register_buffer("_input_std", torch.ones(condition_dim, dtype=torch.float64))

# Output statistics (theta) - diagonal whitening
# Always float64 for precision; results are cast to input dtype on output.
Expand All @@ -123,7 +123,8 @@ def to(self, *args, **kwargs):
"""Move module, preserving parameter-space buffer dtype."""
param_dtype = self._output_mean.dtype
result = super().to(*args, **kwargs)
-        for name in ('_output_mean', '_output_std', '_residual_cov',
+        for name in ('_input_mean', '_input_std',
+                     '_output_mean', '_output_std', '_residual_cov',
'_residual_eigvals', '_residual_eigvecs'):
setattr(result, name, getattr(result, name).to(param_dtype))
return result
Expand Down Expand Up @@ -189,20 +190,21 @@ def sample(self, conditions: torch.Tensor, gamma: Optional[float] = None) -> tor
def _forward_mean(self, conditions: torch.Tensor) -> torch.Tensor:
"""Predict mean using diagonal whitening.

-        Conditions are cast to MLP dtype for input whitening. MLP output
-        is cast to parameter-space dtype before de-whitening.
+        Whitening is done in float64 (statistics precision). The whitened value
+        is cast to the MLP's dtype (float32) before the forward pass, then the
+        MLP output is upcast to parameter-space dtype (float64) before de-whitening.
"""
c = conditions.to(self._input_mean.dtype)
x_norm = (c - self._input_mean.detach()) / self._input_std.detach()
-        r = self.net(x_norm)
+        r = self.net(x_norm.to(next(self.net.parameters()).dtype))
r = r.to(self._output_mean.dtype)
return self._output_mean.detach() + self._output_std.detach() * r

def _update_stats(self, theta: torch.Tensor, conditions: torch.Tensor) -> None:
"""Update running statistics using EMA.

-        Output buffers are float64 for precision. Input buffers stay in MLP
-        dtype (float32); conditions are cast accordingly.
+        Both input and output buffers are float64; conditions and theta are cast
+        accordingly before computing statistics.
"""
m = self.momentum
with torch.no_grad():
Expand Down
Loading