diff --git a/fredipy/models.py b/fredipy/models.py index 7083ed4..12c04b6 100644 --- a/fredipy/models.py +++ b/fredipy/models.py @@ -8,7 +8,7 @@ import scipy as sp from .covariance import TwoSided, OneSided -from .util import make_column_vector +from .util import allclose, make_column_vector class Model: @@ -218,11 +218,12 @@ def _maybe_prepare_inference( self, w_pred: np.ndarray, ) -> None: - if not self._inference_cache: - OpKer = self.OpKer( - self.kernel, self.constraints, w_pred) - self._inference_cache = { - 'OpKer': OpKer} + if self._inference_cache and allclose(w_pred, self._inference_cache['w_pred']): + return + OpKer = self.OpKer( + self.kernel, self.constraints, w_pred) + self._inference_cache = { + 'OpKer': OpKer, 'w_pred': w_pred} class GP(GaussianProcess): diff --git a/tests/test_reconstruction.py b/tests/test_reconstruction.py index eddf820..317bacb 100644 --- a/tests/test_reconstruction.py +++ b/tests/test_reconstruction.py @@ -169,3 +169,39 @@ def test_dressing_1D() -> None: print(devs) assert all(i > 0 for i in devs), \ "Reconstructed data does not match input" + + +def test_predict_different_w_pred() -> None: + """predict() called twice with different grids should give correct results both times""" + w_pred_1 = np.arange(0.5, 5, 0.5) + w_pred_2 = np.arange(0.1, 10, 0.1) + p = np.linspace(0.1, 10, 30) + + a = 1.6 + m = 1 + g = 0.8 + + G = get_G(p, a, m, g) + err = 1e-5 + + data = { + 'x': p, + 'y': G + err * rng.randn(len(G)), + 'cov_y': err**2 * np.ones_like(p)} + + kernel = fp.kernels.RadialBasisFunction(0.5, 0.3) + integrator = fp.integrators.Riemann_1D(0, 10, 500) + integral_op = fp.operators.Integral(kl_kernel, integrator) + constraints = [fp.constraints.LinearEquality(integral_op, data)] + model = fp.models.GaussianProcess(kernel, constraints) + + rho1, err1 = model.predict(w_pred_1) + rho2, err2 = model.predict(w_pred_2) + + ref1 = get_rho(w_pred_1, a, m, g) + ref2 = get_rho(w_pred_2, a, m, g) + + devs1 = err1 - abs(rho1.flatten() - ref1) + devs2 = err2 - abs(rho2.flatten() - ref2) + assert all(i > 0 for i in devs1), "first predict wrong" + assert all(i > 0 for i in devs2), "second predict wrong"