diff --git a/src/nw_analytic_hessian.py b/src/nw_analytic_hessian.py index 828c352..ee6bf86 100644 --- a/src/nw_analytic_hessian.py +++ b/src/nw_analytic_hessian.py @@ -42,15 +42,16 @@ def loocv_mse(x: np.ndarray, y: np.ndarray, h: float, kernel: str) -> Tuple[floa num = w @ y den = w.sum(axis=1) - m = num / den + den_safe = np.where(den == 0, np.finfo(float).eps, den) # Guard against division by zero + m = num / den_safe num1 = w1 @ y den1 = w1.sum(axis=1) - m1 = (num1 * den - num * den1) / (den ** 2) + m1 = (num1 * den_safe - num * den1) / (den_safe ** 2) num2 = w2 @ y den2 = w2.sum(axis=1) - m2 = (num2 * den - num * den2) / (den ** 2) - 2 * m1 * den1 / den + m2 = (num2 * den_safe - num * den2) / (den_safe ** 2) - 2 * m1 * den1 / den_safe resid = y - m loss = np.mean(resid**2)