From ba30c9c955259d4d617a955dd6c4e0c70184c0fc Mon Sep 17 00:00:00 2001 From: gaurav <721466+soodoku@users.noreply.github.com> Date: Sun, 17 Aug 2025 12:50:02 -0700 Subject: [PATCH] Guard against division by zero in LOOCV MSE --- src/nw_analytic_hessian.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nw_analytic_hessian.py b/src/nw_analytic_hessian.py index 828c352..ee6bf86 100644 --- a/src/nw_analytic_hessian.py +++ b/src/nw_analytic_hessian.py @@ -42,15 +42,16 @@ def loocv_mse(x: np.ndarray, y: np.ndarray, h: float, kernel: str) -> Tuple[floa num = w @ y den = w.sum(axis=1) - m = num / den + den_safe = np.where(den == 0, np.finfo(float).eps, den) # Guard against division by zero + m = num / den_safe num1 = w1 @ y den1 = w1.sum(axis=1) - m1 = (num1 * den - num * den1) / (den ** 2) + m1 = (num1 * den_safe - num * den1) / (den_safe ** 2) num2 = w2 @ y den2 = w2.sum(axis=1) - m2 = (num2 * den - num * den2) / (den ** 2) - 2 * m1 * den1 / den + m2 = (num2 * den_safe - num * den2) / (den_safe ** 2) - 2 * m1 * den1 / den_safe resid = y - m loss = np.mean(resid**2)