From befbe8506f35b0db9dc298a3b1f38bba8b41e0bf Mon Sep 17 00:00:00 2001 From: Christoph Weniger Date: Sat, 6 Jun 2026 23:48:20 +0200 Subject: [PATCH 1/2] Fix GaussianPosterior input whitening buffers to float64 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _input_mean and _input_std were initialised with the default dtype (float32), causing the float64 condition inputs to be silently downcast before whitening statistics were computed. This is inconsistent with the output-side buffers, which are explicitly float64, and it loses precision even though the whitened value could still be computed precisely before the cast into the float32 MLP. Fix: initialise both buffers as float64, add them to the to() override so they keep their float64 dtype across .to(...) calls, and decouple the dtype cast — whitening now runs in float64, with an explicit cast to the MLP's dtype immediately before the net() call. Co-Authored-By: Claude Sonnet 4.6 --- falcon/estimators/gaussian_fullcov.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/falcon/estimators/gaussian_fullcov.py b/falcon/estimators/gaussian_fullcov.py index a114b13..945930c 100644 --- a/falcon/estimators/gaussian_fullcov.py +++ b/falcon/estimators/gaussian_fullcov.py @@ -103,9 +103,9 @@ def __init__( # MLP for mean prediction self.net = build_mlp(condition_dim, hidden_dim, param_dim, num_layers) - # Input statistics (conditions) - diagonal whitening - self.register_buffer("_input_mean", torch.zeros(condition_dim)) - self.register_buffer("_input_std", torch.ones(condition_dim)) + # Input statistics (conditions) - diagonal whitening; float64 for precision + self.register_buffer("_input_mean", torch.zeros(condition_dim, dtype=torch.float64)) + self.register_buffer("_input_std", torch.ones(condition_dim, dtype=torch.float64)) # Output statistics (theta) - diagonal whitening # Always float64 for precision; results are cast to input dtype on output. 
@@ -123,7 +123,8 @@ def to(self, *args, **kwargs): """Move module, preserving parameter-space buffer dtype.""" param_dtype = self._output_mean.dtype result = super().to(*args, **kwargs) - for name in ('_output_mean', '_output_std', '_residual_cov', + for name in ('_input_mean', '_input_std', + '_output_mean', '_output_std', '_residual_cov', '_residual_eigvals', '_residual_eigvecs'): setattr(result, name, getattr(result, name).to(param_dtype)) return result @@ -189,20 +190,21 @@ def sample(self, conditions: torch.Tensor, gamma: Optional[float] = None) -> tor def _forward_mean(self, conditions: torch.Tensor) -> torch.Tensor: """Predict mean using diagonal whitening. - Conditions are cast to MLP dtype for input whitening. MLP output - is cast to parameter-space dtype before de-whitening. + Whitening is done in float64 (statistics precision). The whitened value + is cast to the MLP's dtype (float32) before the forward pass, then the + MLP output is upcast to parameter-space dtype (float64) before de-whitening. """ c = conditions.to(self._input_mean.dtype) x_norm = (c - self._input_mean.detach()) / self._input_std.detach() - r = self.net(x_norm) + r = self.net(x_norm.to(next(self.net.parameters()).dtype)) r = r.to(self._output_mean.dtype) return self._output_mean.detach() + self._output_std.detach() * r def _update_stats(self, theta: torch.Tensor, conditions: torch.Tensor) -> None: """Update running statistics using EMA. - Output buffers are float64 for precision. Input buffers stay in MLP - dtype (float32); conditions are cast accordingly. + Both input and output buffers are float64; conditions and theta are cast + accordingly before computing statistics. 
""" m = self.momentum with torch.no_grad(): From 3fd4c737a71fbeec9e022fc5a417ef2426c3ca8f Mon Sep 17 00:00:00 2001 From: Christoph Weniger Date: Sun, 7 Jun 2026 00:02:39 +0200 Subject: [PATCH 2/2] Updates to target data 04_gaussian --- examples/04_gaussian/data/gen_mock_data.py | 2 +- examples/04_gaussian/data/mock_data.npz | Bin 280 -> 280 bytes 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/04_gaussian/data/gen_mock_data.py b/examples/04_gaussian/data/gen_mock_data.py index 69a32cf..986e14e 100755 --- a/examples/04_gaussian/data/gen_mock_data.py +++ b/examples/04_gaussian/data/gen_mock_data.py @@ -11,7 +11,7 @@ import numpy as np # True parameters (no batch dimension) -z_true = np.array([-2.0, 0.0, 2.0]) +z_true = np.array([-5.0, 0.0, 5.0]) # Observation: x = exp(z), no noise (Asimov) x_obs = np.exp(z_true) diff --git a/examples/04_gaussian/data/mock_data.npz b/examples/04_gaussian/data/mock_data.npz index 1f9ed18f6b54b1b509ff7845449a31db77872d9e..feb7d6bda43f64c0430cf620597aebcdcfc5c212 100644 GIT binary patch delta 82 zcmbQiG=oVhz?+#xgaHB+87>%J-!TKo0pW>Kb9Kssc^}x%thR^9f3WX2E|_W2o8%DS O&B!FejH++qV^aVFITR)U delta 82 zcmbQiG=oVhz?+#xmjMD48Md`m6#s{UiBfZQ&NS{{yejCRJw*P4{kpqXcBJ&lIs|w# NGU+m->YMo36adWz9R2_R