From 3728430a8b5448a8ff38b6070e974e2dd8d9dfb2 Mon Sep 17 00:00:00 2001 From: gaurav <721466+soodoku@users.noreply.github.com> Date: Sun, 17 Aug 2025 12:45:36 -0700 Subject: [PATCH] refactor: centralize Newton-Armijo optimisation --- src/kde_analytic_hessian.py | 34 +++-------------------- src/nw_analytic_hessian.py | 35 ++++------------------- src/optim.py | 55 +++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 60 deletions(-) create mode 100644 src/optim.py diff --git a/src/kde_analytic_hessian.py b/src/kde_analytic_hessian.py index 1355ef1..aae5867 100644 --- a/src/kde_analytic_hessian.py +++ b/src/kde_analytic_hessian.py @@ -1,11 +1,12 @@ -"""Newton–Armijo bandwidth selection for univariate KDE.""" +"""Newton–Armijo bandwidth selection for univariate KDE using :func:`optim.newton_armijo`.""" import argparse -from typing import Callable, Tuple +from typing import Tuple import numpy as np from derivatives import KERNELS +from optim import newton_armijo def lscv_generic(x: np.ndarray, h: float, kernel: str) -> Tuple[float, float, float]: @@ -47,33 +48,6 @@ def lscv_generic(x: np.ndarray, h: float, kernel: str) -> Tuple[float, float, fl return score, grad, hess -def newton_armijo( - x: np.ndarray, - h0: float, - kernel: str = "gauss", - tol: float = 1e-5, - max_iter: int = 12, -) -> Tuple[float, int]: - """Run Newton–Armijo to minimise LSCV and return (h_opt, evaluations).""" - h = float(h0) - evals = 0 - for _ in range(max_iter): - f, g, H = lscv_generic(x, h, kernel) - evals += 1 - if abs(g) < tol: - break - step = -g / H if (H > 0 and np.isfinite(H)) else -0.25 * g - if abs(step) / h < 1e-3: - break - for _ in range(10): - h_new = max(h + step, 1e-6) - if lscv_generic(x, h_new, kernel)[0] < f: - h = h_new - break - step *= 0.5 - return h, evals - - def main() -> None: parser = argparse.ArgumentParser(description="Analytic-Hessian KDE bandwidth selection") parser.add_argument("data", nargs="?", help="Path to 1D data (one value per line)") @@ -85,7 +59,7 @@ def main() -> None: x = np.loadtxt(args.data, ndmin=1) else: x = np.random.randn(200) - h, evals = newton_armijo(x, args.h0, kernel=args.kernel) + h, evals = newton_armijo(lscv_generic, x, args.h0, kernel=args.kernel) print(f"Optimal h={h:.5f} after {evals} evaluations") diff --git a/src/nw_analytic_hessian.py b/src/nw_analytic_hessian.py index 5209241..828c352 100644 --- a/src/nw_analytic_hessian.py +++ b/src/nw_analytic_hessian.py @@ -1,10 +1,12 @@ -"""Newton–Armijo bandwidth selection for Nadaraya–Watson regression.""" +"""Newton–Armijo bandwidth selection for Nadaraya–Watson regression using :func:`optim.newton_armijo`.""" import argparse from typing import Tuple import numpy as np +from optim import newton_armijo + SQRT_2PI = np.sqrt(2 * np.pi) @@ -57,34 +59,6 @@ def loocv_mse(x: np.ndarray, y: np.ndarray, h: float, kernel: str) -> Tuple[floa return loss, grad, hess -def newton_armijo( - x: np.ndarray, - y: np.ndarray, - h0: float, - kernel: str = "gauss", - tol: float = 1e-5, - max_iter: int = 12, -) -> Tuple[float, int]: - """Run Newton–Armijo to minimise LOOCV MSE.""" - h = float(h0) - evals = 0 - for _ in range(max_iter): - f, g, H = loocv_mse(x, y, h, kernel) - evals += 1 - if abs(g) < tol: - break - step = -g / H if (H > 0 and np.isfinite(H)) else -0.25 * g - if abs(step) / h < 1e-3: - break - for _ in range(10): - h_new = max(h + step, 1e-6) - if loocv_mse(x, y, h_new, kernel)[0] < f: - h = h_new - break - step *= 0.5 - return h, evals - - def main() -> None: parser = argparse.ArgumentParser(description="Analytic-Hessian NW bandwidth selection") parser.add_argument("data", nargs="?", help="Path to data with two columns x,y") @@ -98,7 +72,8 @@ def main() -> None: else: x = np.linspace(-2, 2, 200) y = np.sin(x) + 0.1 * np.random.randn(len(x)) - h, evals = newton_armijo(x, y, args.h0, kernel=args.kernel) + objective = lambda x_, h, k: loocv_mse(x_, y, h, k) + h, evals = newton_armijo(objective, x, args.h0, kernel=args.kernel) print(f"Optimal h={h:.5f} after {evals} evaluations") diff --git a/src/optim.py b/src/optim.py new file mode 100644 index 0000000..452e32e --- /dev/null +++ b/src/optim.py @@ -0,0 +1,55 @@ +"""Optimisation utilities.""" + +from typing import Callable, Tuple + +import numpy as np + + +def newton_armijo( + objective: Callable[[np.ndarray, float, str], Tuple[float, float, float]], + x: np.ndarray, + h0: float, + kernel: str = "gauss", + tol: float = 1e-5, + max_iter: int = 12, +) -> Tuple[float, int]: + """Run Newton–Armijo iterations for a generic objective. + + Parameters + ---------- + objective: + Callable returning ``(score, grad, hess)`` for given ``(x, h, kernel)``. + x: + Sample locations passed to ``objective``. + h0: + Initial bandwidth guess. + kernel: + Kernel name forwarded to ``objective``. + tol: + Tolerance for gradient magnitude to stop optimisation. + max_iter: + Maximum number of Newton updates. + + Returns + ------- + Tuple[float, int] + Optimised bandwidth and number of objective evaluations. + """ + + h = float(h0) + evals = 0 + for _ in range(max_iter): + f, g, H = objective(x, h, kernel) + evals += 1 + if abs(g) < tol: + break + step = -g / H if (H > 0 and np.isfinite(H)) else -0.25 * g + if abs(step) / h < 1e-3: + break + for _ in range(10): + h_new = max(h + step, 1e-6) + if objective(x, h_new, kernel)[0] < f: + h = h_new + break + step *= 0.5 + return h, evals