diff --git a/fredipy/kernels.py b/fredipy/kernels.py index e4892ac..3a2c87c 100644 --- a/fredipy/kernels.py +++ b/fredipy/kernels.py @@ -164,7 +164,7 @@ def empy(x, y): None grad = [empy for _ in range(self.dim)] grad[0] = lambda x, y: 2 / self.variance * self.make(x, y) for i in range(self.dim - 1): - grad[i + 1] = lambda x, y: self.make(x, y) * ( + grad[i + 1] = lambda x, y, i=i: self.make(x, y) * ( np.sum(make_column_vector(x[:, i])**2, 1)[:, None] + np.sum(make_column_vector(y[:, i])**2, 1) - 2 * make_column_vector(x[:, i]) @ make_row_vector(y[:, i]) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index fb6e47e..a886dd8 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -289,3 +289,46 @@ def test_matern52_set_params( assert kernel.variance == params[0] assert kernel.lengthscale == params[1:] + + +@pytest.mark.parametrize("variance, lengthscale", [ + [2.0, 1.5], # 1D + [2.0, [0.5, 1.5]], # 2D -- catches late-binding closure bugs + ]) +def test_rbf_params_gradient( + variance: float, + lengthscale: float | list[float], + ) -> None: + """Check analytical gradient against finite differences for each hyperparameter.""" + kernel = kernels.RadialBasisFunction(variance=variance, lengthscale=lengthscale) + n_dim = len(kernel.lengthscale) + n_test = 5 + X = rng.randn(n_test, n_dim) + Y = rng.randn(n_test, n_dim) + + params = [kernel.variance] + list(kernel.lengthscale) + + # collect analytical gradients before mutating kernel state + grad = kernel.params_gradient() + analytical = [grad[idx](X, Y) for idx in range(kernel.dim)] + + eps = 1e-5 + for idx in range(kernel.dim): + # central finite difference + params_plus = list(params) + params_plus[idx] += eps + params_minus = list(params) + params_minus[idx] -= eps + + kernel.set_params(params_plus) + kernel._empty_cache() + K_plus = kernel.make(X, Y, cache=False) + kernel.set_params(params_minus) + kernel._empty_cache() + K_minus = kernel.make(X, Y, cache=False) + + numerical = (K_plus - K_minus) / (2 * eps) + assert_allclose(analytical[idx], numerical, rtol=1e-4) + + # restore + kernel.set_params(params)