diff --git a/examples/04_gaussian/data/gen_mock_data.py b/examples/04_gaussian/data/gen_mock_data.py index 69a32cf..986e14e 100755 --- a/examples/04_gaussian/data/gen_mock_data.py +++ b/examples/04_gaussian/data/gen_mock_data.py @@ -11,7 +11,7 @@ import numpy as np # True parameters (no batch dimension) -z_true = np.array([-2.0, 0.0, 2.0]) +z_true = np.array([-5.0, 0.0, 5.0]) # Observation: x = exp(z), no noise (Asimov) x_obs = np.exp(z_true) diff --git a/examples/04_gaussian/data/mock_data.npz b/examples/04_gaussian/data/mock_data.npz index 1f9ed18..feb7d6b 100644 Binary files a/examples/04_gaussian/data/mock_data.npz and b/examples/04_gaussian/data/mock_data.npz differ diff --git a/falcon/estimators/gaussian_fullcov.py b/falcon/estimators/gaussian_fullcov.py index a114b13..945930c 100644 --- a/falcon/estimators/gaussian_fullcov.py +++ b/falcon/estimators/gaussian_fullcov.py @@ -103,9 +103,9 @@ def __init__( # MLP for mean prediction self.net = build_mlp(condition_dim, hidden_dim, param_dim, num_layers) - # Input statistics (conditions) - diagonal whitening - self.register_buffer("_input_mean", torch.zeros(condition_dim)) - self.register_buffer("_input_std", torch.ones(condition_dim)) + # Input statistics (conditions) - diagonal whitening; float64 for precision + self.register_buffer("_input_mean", torch.zeros(condition_dim, dtype=torch.float64)) + self.register_buffer("_input_std", torch.ones(condition_dim, dtype=torch.float64)) # Output statistics (theta) - diagonal whitening # Always float64 for precision; results are cast to input dtype on output. @@ -123,7 +123,8 @@ def to(self, *args, **kwargs): """Move module, preserving parameter-space buffer dtype.""" param_dtype = self._output_mean.dtype result = super().to(*args, **kwargs) - for name in ('_output_mean', '_output_std', '_residual_cov', + for name in ('_input_mean', '_input_std', + '_output_mean', '_output_std', '_residual_cov', '_residual_eigvals', '_residual_eigvecs'): setattr(result, name, getattr(result, name).to(param_dtype)) return result @@ -189,20 +190,21 @@ def sample(self, conditions: torch.Tensor, gamma: Optional[float] = None) -> tor def _forward_mean(self, conditions: torch.Tensor) -> torch.Tensor: """Predict mean using diagonal whitening. - Conditions are cast to MLP dtype for input whitening. MLP output - is cast to parameter-space dtype before de-whitening. + Whitening is done in float64 (statistics precision). The whitened value + is cast to the MLP's dtype (float32) before the forward pass, then the + MLP output is upcast to parameter-space dtype (float64) before de-whitening. """ c = conditions.to(self._input_mean.dtype) x_norm = (c - self._input_mean.detach()) / self._input_std.detach() - r = self.net(x_norm) + r = self.net(x_norm.to(next(self.net.parameters()).dtype)) r = r.to(self._output_mean.dtype) return self._output_mean.detach() + self._output_std.detach() * r def _update_stats(self, theta: torch.Tensor, conditions: torch.Tensor) -> None: """Update running statistics using EMA. - Output buffers are float64 for precision. Input buffers stay in MLP - dtype (float32); conditions are cast accordingly. + Both input and output buffers are float64; conditions and theta are cast + accordingly before computing statistics. """ m = self.momentum with torch.no_grad():