diff --git a/src/humancompatible/train/dual_optim/nonopt/__init__.py b/src/humancompatible/train/dual_optim/nonopt/__init__.py new file mode 100644 index 0000000..989bfd0 --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/__init__.py @@ -0,0 +1,12 @@ +""" +PyTorch port of NonOpt (https://frankecurtis.github.io/NonOpt/), a solver for +unconstrained, locally Lipschitz (possibly nonconvex, nonsmooth) minimization +by Frank E. Curtis and collaborators (Curtis & Zebiane, arXiv:2503.22826). +""" + +from .direction import CuttingPlane, GradientCombination, GradientDirection +from .inverse_hessian import DenseInverseHessian, LimitedMemoryInverseHessian +from .line_search import backtracking, weak_wolfe +from .optimizer import NonOpt +from .point_set import Point, PointSet +from .qp import project_onto_simplex, solve_simplex_qp diff --git a/src/humancompatible/train/dual_optim/nonopt/direction.py b/src/humancompatible/train/dual_optim/nonopt/direction.py new file mode 100644 index 0000000..2a7c09c --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/direction.py @@ -0,0 +1,421 @@ +""" +Direction computation strategies, ports of ``NonOptDirectionComputationGradient``, +``NonOptDirectionComputationCuttingPlane`` (proximal-bundle, the NonOpt default) +and ``NonOptDirectionComputationGradientCombination`` (gradient-sampling) from +https://github.com/frankecurtis/NonOpt. + +Each strategy assembles a bundle of (sub)gradients ``G`` and cutting-plane +values ``c``, solves the dual subproblem + +.. math:: + \\min_{\\omega \\in \\Delta} \\tfrac{1}{2} \\omega^T (G^T W G)\\, \\omega + - (c - f_k \\mathbf{1})^T \\omega, + +and returns the search direction :math:`d = -W G \\omega` together with the +quantities the termination test and line search need. Unlike the C++ +implementation, no trust-region constraint is imposed on the subproblem (the +C++ default trust-region radius of ``1e+10 * ||g||`` makes it inactive in +practice anyway). +""" + +from dataclasses import dataclass +from typing import Callable, Tuple + +import torch +from torch import Tensor + +from .inverse_hessian import InverseHessian +from .point_set import Point, PointSet +from .qp import solve_simplex_qp + + +@dataclass +class DirectionResult: + """Search direction and the associated subproblem quantities.""" + + direction: Tensor + omega: Tensor + combination: Tensor #: convex combination of bundle gradients, ``G @ omega`` + dual_quadratic_value: float #: ``(G w)' W (G w)`` + direction_norm_inf: float + direction_norm2_squared: float + combination_norm_inf: float + combination_norm2_squared: float + radii_update_triggered: bool = False + + @property + def decrease_reference(self) -> float: + """Reference value for sufficient-decrease tests: + ``min(dual quadratic, max(||G w||^2, ||d||^2))``.""" + return min( + self.dual_quadratic_value, + max(self.combination_norm2_squared, self.direction_norm2_squared), + ) + + +def _solve_bundle_subproblem( + gradients: list, + cut_values: list, + f_current: float, + inverse_hessian: InverseHessian, +) -> DirectionResult: + """Solves the dual subproblem for the given bundle and recovers the primal + direction.""" + G = torch.stack(gradients, dim=1) + WG = inverse_hessian.apply_matrix(G) + Q = (G.t() @ WG).double() + b = torch.tensor( + [value - f_current for value in cut_values], dtype=torch.float64, device=G.device + ) + omega = solve_simplex_qp(Q, b).to(G.dtype) + direction = -(WG @ omega) + combination = G @ omega + return DirectionResult( + direction=direction, + omega=omega, + combination=combination, + dual_quadratic_value=float(omega.double() @ (Q @ omega.double())), + direction_norm_inf=float(direction.abs().max()), + direction_norm2_squared=float(direction.dot(direction)), + combination_norm_inf=float(combination.abs().max()), + combination_norm2_squared=float(combination.dot(combination)), + ) + + +class DirectionComputation: + """ + Base class for direction computation strategies. + + :param step_acceptance_tolerance: Tolerance for the sufficient-decrease test + that terminates the inner loop early. + :type step_acceptance_tolerance: float + :param downshift_constant: Cutting-plane downshift constant; the linear term + of an added cut is the minimum of its linearization value and the current + objective minus this constant times the squared distance to the iterate. + :type downshift_constant: float + :param inner_iteration_limit: Limit on inner (re-solve) iterations. + :type inner_iteration_limit: int + :param try_gradient_step: Whether to first try a cheap steepest-descent-like + step before assembling the full bundle. + :type try_gradient_step: bool + :param gradient_stepsize: Stepsize used by the tentative gradient step. + :type gradient_stepsize: float + :param try_shortened_step: Whether to also evaluate a shortened step in each + inner iteration (enriches the bundle near the iterate). + :type try_shortened_step: bool + :param shortened_stepsize: Stepsize factor for the shortened step. + :type shortened_stepsize: float + :param add_far_points: Whether to add trial points lying outside the + stationarity radius to the bundle. + :type add_far_points: bool + """ + + def __init__( + self, + step_acceptance_tolerance: float = 1e-08, + downshift_constant: float = 1e-01, + inner_iteration_limit: int = 2, + try_gradient_step: bool = True, + gradient_stepsize: float = 1e-03, + try_shortened_step: bool = True, + shortened_stepsize: float = 1e-03, + add_far_points: bool = False, + ) -> None: + self.step_acceptance_tolerance = step_acceptance_tolerance + self.downshift_constant = downshift_constant + self.inner_iteration_limit = inner_iteration_limit + self.try_gradient_step = try_gradient_step + self.gradient_stepsize = gradient_stepsize + self.try_shortened_step = try_shortened_step + self.shortened_stepsize = shortened_stepsize + self.add_far_points = add_far_points + + def compute( + self, + evaluate: Callable[[Tensor], Tuple[float, Tensor]], + x: Tensor, + f: float, + g: Tensor, + point_set: PointSet, + inverse_hessian: InverseHessian, + stationarity_radius: float, + radii_update_check: Callable[[DirectionResult], bool], + ) -> DirectionResult: + """ + Computes a search direction at the current iterate. + + :param evaluate: Callable mapping a flat iterate to ``(objective, gradient)``. + :type evaluate: Callable + :param x: Current iterate (flat). + :type x: torch.Tensor + :param f: Objective value at ``x``. + :type f: float + :param g: (Sub)gradient at ``x``. + :type g: torch.Tensor + :param point_set: Point set; may be augmented with trial points. + :type point_set: PointSet + :param inverse_hessian: Inverse Hessian approximation. + :type inverse_hessian: InverseHessian + :param stationarity_radius: Current stationarity radius. + :type stationarity_radius: float + :param radii_update_check: Predicate implementing the termination + strategy's radii-update test for a candidate direction. + :type radii_update_check: Callable + :return: Search direction and subproblem quantities. + :rtype: DirectionResult + """ + raise NotImplementedError + + # -- shared building blocks ------------------------------------------------- + + def _cut_value(self, x: Tensor, f: float, point: Point, linearize: bool) -> float: + """Cutting-plane linear term for a bundle point: the minimum of the + linearization value (if used by the strategy) and the downshifted value.""" + difference = x - point.x + downshift = f - self.downshift_constant * float(difference.dot(difference)) + if not linearize: + return downshift + linearization = point.f + float(point.g.dot(difference)) + return min(linearization, downshift) + + def _try_gradient_step(self, evaluate, x, f, g, inverse_hessian, radii_update_check): + """Tentative steepest-descent-like step; returns an accepted result or None.""" + result = _solve_bundle_subproblem([g], [f], f, inverse_hessian) + x_trial = x + self.gradient_stepsize * result.direction + f_trial, _ = evaluate(x_trial) + radii_flag = radii_update_check(result) + accepted = ( + f_trial - f + < -self.step_acceptance_tolerance * self.gradient_stepsize * result.decrease_reference + ) + if accepted or radii_flag: + result.radii_update_triggered = radii_flag + return result + return None + + def _inner_loop( + self, + evaluate, + x, + f, + point_set, + inverse_hessian, + stationarity_radius, + radii_update_check, + gradients, + cut_values, + linearize, + sample=None, + ) -> DirectionResult: + """Inner loop: alternates between evaluating the trial point implied by + the current bundle and re-solving the subproblem with an enriched bundle.""" + result = _solve_bundle_subproblem(gradients, cut_values, f, inverse_hessian) + inner_iteration = 1 + + while True: + x_trial = x + result.direction + f_trial, g_trial = evaluate(x_trial) + radii_flag = radii_update_check(result) + if ( + f_trial - f + < -self.step_acceptance_tolerance * result.decrease_reference + ) or radii_flag: + result.radii_update_triggered = radii_flag + return result + + if inner_iteration > self.inner_iteration_limit: + return result + + # add the (full-step) trial point to the point set and bundle + if self.add_far_points or result.direction_norm_inf <= stationarity_radius: + trial_point = Point(x_trial, f_trial, g_trial) + point_set.add(trial_point) + gradients.append(g_trial) + cut_values.append(self._cut_value(x, f, trial_point, linearize)) + + # add a shortened-step point near the iterate + if self.try_shortened_step and result.direction_norm_inf > 0.0: + shortened = ( + self.shortened_stepsize + * min(stationarity_radius, result.direction_norm_inf) + / result.direction_norm_inf + ) + x_short = x + shortened * result.direction + f_short, g_short = evaluate(x_short) + radii_flag = radii_update_check(result) + if ( + f_short - f + < -self.step_acceptance_tolerance * shortened * result.decrease_reference + ) or radii_flag: + result.radii_update_triggered = radii_flag + return result + short_point = Point(x_short, f_short, g_short) + point_set.add(short_point) + gradients.append(g_short) + cut_values.append(self._cut_value(x, f, short_point, linearize)) + + if sample is not None: + sample(gradients, cut_values) + + result = _solve_bundle_subproblem(gradients, cut_values, f, inverse_hessian) + inner_iteration += 1 + + +class GradientDirection(DirectionComputation): + """ + Plain quasi-Newton direction :math:`d = -W g` from the gradient at the + current iterate only (port of ``NonOptDirectionComputationGradient``). + """ + + def compute( + self, + evaluate, + x, + f, + g, + point_set, + inverse_hessian, + stationarity_radius, + radii_update_check, + ) -> DirectionResult: + result = _solve_bundle_subproblem([g], [f], f, inverse_hessian) + result.radii_update_triggered = radii_update_check(result) + return result + + +class CuttingPlane(DirectionComputation): + """ + Proximal-bundle (cutting-plane) direction computation, the NonOpt default + (port of ``NonOptDirectionComputationCuttingPlane``). The bundle is seeded + with the gradients of nearby points from the point set, with linear terms + set from downshifted cutting planes, and enriched with trial points until a + sufficient model decrease is realized. + """ + + def compute( + self, + evaluate, + x, + f, + g, + point_set, + inverse_hessian, + stationarity_radius, + radii_update_check, + ) -> DirectionResult: + if self.try_gradient_step: + result = self._try_gradient_step( + evaluate, x, f, g, inverse_hessian, radii_update_check + ) + if result is not None: + return result + + gradients = [g] + cut_values = [f] + for point in point_set: + if float((x - point.x).abs().max()) <= stationarity_radius: + gradients.append(point.g) + cut_values.append(self._cut_value(x, f, point, linearize=True)) + + return self._inner_loop( + evaluate, + x, + f, + point_set, + inverse_hessian, + stationarity_radius, + radii_update_check, + gradients, + cut_values, + linearize=True, + ) + + +class GradientCombination(DirectionComputation): + """ + Gradient-sampling-style direction computation (port of + ``NonOptDirectionComputationGradientCombination``). In addition to nearby + points from the point set, gradients are evaluated at randomly sampled + points within the stationarity radius of the current iterate; linear terms + use the downshifted value only. + + :param random_sample_factor: Number of points to sample per subproblem + solve. If at least 1, it is the absolute number of points; otherwise + the number sampled is this factor times the number of variables + (rounded down, at least one point is always sampled). + :type random_sample_factor: float + :param generator: Optional ``torch.Generator`` for reproducible sampling. + :type generator: torch.Generator + """ + + def __init__( + self, + random_sample_factor: float = 10, + generator: torch.Generator = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.random_sample_factor = random_sample_factor + self.generator = generator + + def _sample_count(self, n: int) -> int: + if self.random_sample_factor >= 1.0: + return int(self.random_sample_factor) + return max(1, int(self.random_sample_factor * n)) + + def compute( + self, + evaluate, + x, + f, + g, + point_set, + inverse_hessian, + stationarity_radius, + radii_update_check, + ) -> DirectionResult: + if self.try_gradient_step: + result = self._try_gradient_step( + evaluate, x, f, g, inverse_hessian, radii_update_check + ) + if result is not None: + return result + + gradients = [g] + cut_values = [f] + + def sample(gradients, cut_values): + for _ in range(self._sample_count(x.numel())): + perturbation = ( + torch.rand( + x.shape, dtype=x.dtype, device=x.device, generator=self.generator + ) + * 2.0 + - 1.0 + ) + x_sample = x + stationarity_radius * perturbation + f_sample, g_sample = evaluate(x_sample) + sampled_point = Point(x_sample, f_sample, g_sample) + point_set.add(sampled_point) + gradients.append(g_sample) + cut_values.append(self._cut_value(x, f, sampled_point, linearize=False)) + + for point in point_set: + if float((x - point.x).abs().max()) <= stationarity_radius: + gradients.append(point.g) + cut_values.append(self._cut_value(x, f, point, linearize=False)) + sample(gradients, cut_values) + + return self._inner_loop( + evaluate, + x, + f, + point_set, + inverse_hessian, + stationarity_radius, + radii_update_check, + gradients, + cut_values, + linearize=False, + sample=sample, + ) diff --git a/src/humancompatible/train/dual_optim/nonopt/inverse_hessian.py b/src/humancompatible/train/dual_optim/nonopt/inverse_hessian.py new file mode 100644 index 0000000..b703eaa --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/inverse_hessian.py @@ -0,0 +1,276 @@ +""" +Inverse Hessian approximations used by the NonOpt port, with the self-correcting +quasi-Newton updates of Curtis and Que. + +Ports of ``NonOptApproximateHessianUpdateBFGS``/``DFP`` and +``NonOptSymmetricMatrixDense``/``LimitedMemory`` from +https://github.com/frankecurtis/NonOpt. The classes here maintain the *inverse* +Hessian approximation :math:`W \\approx H^{-1}` directly, since the direction +computation only needs products :math:`W v`. +""" + +import math +from collections import deque + +import torch +from torch import Tensor + + +def self_correcting_scalar( + u: float, + v: float, + w: float, + correction_threshold_1: float, + correction_threshold_2: float, +) -> float: + """ + Computes the self-correcting BFGS damping scalar :math:`\\phi` such that the + corrected gradient displacement :math:`\\tilde y = (1-\\phi) y + \\phi s` + satisfies + :math:`\\langle s, \\tilde y\\rangle / \\langle s, s\\rangle \\geq \\eta_1` and + :math:`\\langle \\tilde y, \\tilde y\\rangle / \\langle s, \\tilde y\\rangle \\leq \\eta_2`. + + Direct port of ``ApproximateHessianUpdateBFGS::evaluateSelfCorrectingScalar``. + + :param u: Squared norm of the iterate displacement, :math:`\\|s\\|_2^2`. + :type u: float + :param v: Inner product :math:`\\langle s, y\\rangle`. + :type v: float + :param w: Squared norm of the gradient displacement, :math:`\\|y\\|_2^2`. + :type w: float + :param correction_threshold_1: Lower-bound threshold :math:`\\eta_1`. + :type correction_threshold_1: float + :param correction_threshold_2: Upper-bound threshold :math:`\\eta_2`. + :type correction_threshold_2: float + :return: Correction scalar in :math:`[0, 1]`. + :rtype: float + """ + eta1, eta2 = correction_threshold_1, correction_threshold_2 + + # scalar for the lower bound on / + scalar1 = 0.0 + if u <= 0.0: + scalar1 = 1.0 + else: + if v / u < eta1: + if eta1 * u - v > 0.0 and u - v > 0.0: + scalar1 = (eta1 * u - v) / (u - v) + else: + scalar1 = 0.0 + if ( + scalar1 > 0.0 + and scalar1**2 * u + scalar1 * (1.0 - scalar1) * v + (1.0 - scalar1) ** 2 * w + > eta1 * (scalar1 * u + (1.0 - scalar1) * v) + ): + scalar1 = 1.0 + + # scalar for the upper bound on / + scalar2 = 0.0 + if v <= 0.0: + scalar2 = 1.0 + elif w / v > eta2: + temporary1 = u - 2.0 * v + w + temporary2 = 2.0 * (v - w) - eta2 * (u - v) + temporary3 = w - eta2 * v + discriminant = temporary2**2 - 4.0 * temporary1 * temporary3 + if temporary3 > 0.0 and discriminant >= 0.0 and -temporary2 + math.sqrt(discriminant) > 0.0: + scalar2 = 2.0 * temporary3 / (-temporary2 + math.sqrt(discriminant)) + else: + scalar2 = 1.0 + + return max(scalar1, scalar2) + + +class InverseHessian: + """ + Base class for inverse Hessian approximations. Implements the displacement + correction and update-skipping logic shared by all updates. + + :param correction_threshold_1: Self-correction lower-bound threshold; if the + update is corrected, the gradient displacement ``y`` is modified so that + ``/`` is at least this value. + :type correction_threshold_1: float + :param correction_threshold_2: Self-correction upper-bound threshold; if the + update is corrected, ``y`` is modified so that ``/`` is at most + this value. + :type correction_threshold_2: float + :param norm_tolerance: Update is skipped if either displacement norm falls + below this tolerance. + :type norm_tolerance: float + :param product_tolerance: Update is skipped unless + `` >= product_tolerance * ||s|| * ||y||``. + :type product_tolerance: float + :param initial_scaling: Whether to scale the initial matrix by + ``/`` at the first successful update. + :type initial_scaling: bool + """ + + def __init__( + self, + correction_threshold_1: float = 1e-08, + correction_threshold_2: float = 1e+08, + norm_tolerance: float = 1e-08, + product_tolerance: float = 1e-20, + initial_scaling: bool = False, + ) -> None: + self.correction_threshold_1 = correction_threshold_1 + self.correction_threshold_2 = correction_threshold_2 + self.norm_tolerance = norm_tolerance + self.product_tolerance = product_tolerance + self.initial_scaling = initial_scaling + self.initial_update_performed = False + + def _corrected_displacements(self, s: Tensor, y: Tensor): + """Returns the corrected gradient displacement, or None if the update + should be skipped.""" + if float(s.norm()) <= self.norm_tolerance: + return None + if float(y.norm()) <= self.norm_tolerance: + return None + scalar = self_correcting_scalar( + float(s.dot(s)), + float(s.dot(y)), + float(y.dot(y)), + self.correction_threshold_1, + self.correction_threshold_2, + ) + if scalar > 0.0: + y = (1.0 - scalar) * y + scalar * s + if float(s.dot(y)) < self.product_tolerance * float(s.norm()) * float(y.norm()): + return None + return y + + def apply(self, v: Tensor) -> Tensor: + """ + Computes the product :math:`W v`. + + :param v: Vector of shape ``(n,)``. + :type v: torch.Tensor + :return: Product :math:`W v`. + :rtype: torch.Tensor + """ + return self.apply_matrix(v.unsqueeze(1)).squeeze(1) + + def apply_matrix(self, V: Tensor) -> Tensor: + """ + Computes the product :math:`W V` column-wise. + + :param V: Matrix of shape ``(n, k)``. + :type V: torch.Tensor + :return: Product :math:`W V` of shape ``(n, k)``. + :rtype: torch.Tensor + """ + raise NotImplementedError + + def update(self, s: Tensor, y: Tensor) -> bool: + """ + Updates the approximation from the iterate displacement ``s`` and the + gradient displacement ``y`` (after self-correction). + + :param s: Iterate displacement of shape ``(n,)``. + :type s: torch.Tensor + :param y: Gradient displacement of shape ``(n,)``. + :type y: torch.Tensor + :return: Whether the update was performed (False if skipped). + :rtype: bool + """ + raise NotImplementedError + + +class LimitedMemoryInverseHessian(InverseHessian): + """ + Limited-memory BFGS inverse Hessian approximation with self-correcting + updates. Products :math:`W V` are computed with the (vectorized) two-loop + recursion; nothing of size :math:`n \\times n` is ever stored. + + :param history_size: Number of curvature pairs kept. + :type history_size: int + """ + + def __init__(self, history_size: int = 20, **kwargs) -> None: + super().__init__(**kwargs) + self.history_size = history_size + self.pairs = deque(maxlen=history_size) + self.gamma = 1.0 + + def reset(self) -> None: + self.pairs.clear() + self.gamma = 1.0 + self.initial_update_performed = False + + def update(self, s: Tensor, y: Tensor) -> bool: + y = self._corrected_displacements(s, y) + if y is None: + return False + if self.initial_scaling and not self.initial_update_performed: + self.gamma = float(s.dot(y)) / float(y.dot(y)) + self.initial_update_performed = True + rho = 1.0 / float(s.dot(y)) + self.pairs.append((s.clone(), y.clone(), rho)) + return True + + def apply_matrix(self, V: Tensor) -> Tensor: + Q = V.clone() + alphas = [] + for s, y, rho in reversed(self.pairs): + alpha = rho * (s @ Q) # shape (k,) + Q -= torch.outer(y, alpha) + alphas.append(alpha) + R = self.gamma * Q + for (s, y, rho), alpha in zip(self.pairs, reversed(alphas)): + beta = rho * (y @ R) + R += torch.outer(s, alpha - beta) + return R + + +class DenseInverseHessian(InverseHessian): + """ + Dense inverse Hessian approximation supporting self-correcting BFGS and DFP + updates. Stores the full ``(n, n)`` matrix; only suitable for problems with + a moderate number of variables. + + :param formula: Quasi-Newton update formula, ``"bfgs"`` or ``"dfp"``. + :type formula: str + """ + + def __init__(self, formula: str = "bfgs", **kwargs) -> None: + super().__init__(**kwargs) + if formula not in ("bfgs", "dfp"): + raise ValueError(f"Unknown quasi-Newton update formula: {formula}!") + self.formula = formula + self.W = None + + def reset(self) -> None: + self.W = None + self.initial_update_performed = False + + def _materialize(self, like: Tensor) -> None: + if self.W is None: + n = like.numel() + self.W = torch.eye(n, dtype=like.dtype, device=like.device) + + def update(self, s: Tensor, y: Tensor) -> bool: + y = self._corrected_displacements(s, y) + if y is None: + return False + self._materialize(s) + if self.initial_scaling and not self.initial_update_performed: + self.W = (float(s.dot(y)) / float(y.dot(y))) * torch.eye( + s.numel(), dtype=s.dtype, device=s.device + ) + self.initial_update_performed = True + if self.formula == "bfgs": + rho = 1.0 / float(s.dot(y)) + Wy = self.W @ y + self.W -= rho * (torch.outer(s, Wy) + torch.outer(Wy, s)) + self.W += rho * (1.0 + rho * float(y.dot(Wy))) * torch.outer(s, s) + else: # dfp + Wy = self.W @ y + self.W -= torch.outer(Wy, Wy) / float(y.dot(Wy)) + self.W += torch.outer(s, s) / float(s.dot(y)) + return True + + def apply_matrix(self, V: Tensor) -> Tensor: + if self.W is None: + return V.clone() + return self.W @ V diff --git a/src/humancompatible/train/dual_optim/nonopt/line_search.py b/src/humancompatible/train/dual_optim/nonopt/line_search.py new file mode 100644 index 0000000..fa584d6 --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/line_search.py @@ -0,0 +1,156 @@ +""" +Line search strategies, ports of ``NonOptLineSearchWeakWolfe`` and +``NonOptLineSearchBacktracking`` from https://github.com/frankecurtis/NonOpt. + +Both searches measure sufficient decrease against a model-based reference value +supplied by the direction computation, +``min(d' H d, max(||G w||^2, ||d||^2))``, rather than the directional +derivative, as nonsmoothness makes the latter unreliable. +""" + +from typing import Callable, NamedTuple, Tuple + +from torch import Tensor + + +class LineSearchResult(NamedTuple): + """Outcome of a line search: accepted stepsize and trial point data.""" + + stepsize: float + x: Tensor + f: float + g: Tensor + + +def weak_wolfe( + evaluate: Callable[[Tensor], Tuple[float, Tensor]], + x: Tensor, + f: float, + g: Tensor, + d: Tensor, + stepsize_previous: float, + decrease_reference: float, + *, + stepsize_initial: float = 1.0, + stepsize_minimum: float = 1e-20, + stepsize_maximum: float = 1e+02, + sufficient_decrease_threshold: float = 1e-10, + sufficient_decrease_fudge_factor: float = 1e-10, + curvature_threshold: float = 9e-01, + curvature_fudge_factor: float = 1e-10, + stepsize_decrease_factor: float = 5e-01, + stepsize_increase_factor: float = 1e+01, + stepsize_bound_tolerance: float = 1e-20, +) -> LineSearchResult: + """ + Weak Wolfe line search along direction ``d``. Brackets a stepsize satisfying + a sufficient decrease condition (relative to ``decrease_reference``) and a + weak curvature condition. If the bracketing interval collapses, any simple + decrease is accepted; otherwise a null step (stepsize 0) is returned. + + :param evaluate: Callable mapping a flat iterate to ``(objective, gradient)``. + :type evaluate: Callable + :param x: Current iterate (flat). + :type x: torch.Tensor + :param f: Objective value at ``x``. + :type f: float + :param g: Gradient at ``x``. + :type g: torch.Tensor + :param d: Search direction. + :type d: torch.Tensor + :param stepsize_previous: Stepsize accepted in the previous iteration; the + initial trial stepsize is ``stepsize_increase_factor`` times this value, + capped by ``stepsize_initial``. + :type stepsize_previous: float + :param decrease_reference: Model decrease reference from the QP subproblem. + :type decrease_reference: float + :return: Accepted stepsize and trial point data. + :rtype: LineSearchResult + """ + directional_derivative = float(g.dot(d)) + lower = stepsize_minimum + upper = stepsize_maximum + stepsize = max( + stepsize_minimum, + min(stepsize_increase_factor * stepsize_previous, min(stepsize_initial, stepsize_maximum)), + ) + + while True: + x_trial = x + stepsize * d + f_trial, g_trial = evaluate(x_trial) + + sufficient_decrease = ( + f_trial - f + <= -sufficient_decrease_threshold * stepsize * decrease_reference + + sufficient_decrease_fudge_factor + ) + if sufficient_decrease: + curvature_condition = ( + float(g_trial.dot(d)) + >= curvature_threshold * directional_derivative - curvature_fudge_factor + ) + if curvature_condition: + return LineSearchResult(stepsize, x_trial, f_trial, g_trial) + + # interval collapsed: accept simple decrease or take a null step + if ( + stepsize <= lower + stepsize_bound_tolerance + or stepsize >= upper - stepsize_bound_tolerance + ): + if f_trial < f: + return LineSearchResult(stepsize, x_trial, f_trial, g_trial) + return LineSearchResult(0.0, x, f, g) + + if sufficient_decrease: + lower = stepsize + else: + upper = stepsize + stepsize = (1.0 - stepsize_decrease_factor) * lower + stepsize_decrease_factor * upper + + +def backtracking( + evaluate: Callable[[Tensor], Tuple[float, Tensor]], + x: Tensor, + f: float, + g: Tensor, + d: Tensor, + stepsize_previous: float, + decrease_reference: float, + *, + stepsize_initial: float = 1.0, + stepsize_minimum: float = 1e-20, + sufficient_decrease_threshold: float = 1e-10, + sufficient_decrease_fudge_factor: float = 1e-10, + stepsize_decrease_factor: float = 5e-01, + stepsize_increase_factor: float = 1e+01, +) -> LineSearchResult: + """ + Backtracking (Armijo) line search along direction ``d``. Same interface and + acceptance reference as :func:`weak_wolfe`, but without a curvature + condition. + + :return: Accepted stepsize and trial point data. + :rtype: LineSearchResult + """ + stepsize = max( + stepsize_minimum, + min(stepsize_increase_factor * stepsize_previous, stepsize_initial), + ) + + while True: + x_trial = x + stepsize * d + f_trial, g_trial = evaluate(x_trial) + + if ( + f_trial - f + <= -sufficient_decrease_threshold * stepsize * decrease_reference + + sufficient_decrease_fudge_factor + ): + return LineSearchResult(stepsize, x_trial, f_trial, g_trial) + + if stepsize <= stepsize_minimum: + if f_trial < f: + return LineSearchResult(stepsize, x_trial, f_trial, g_trial) + return LineSearchResult(0.0, x, f, g) + + stepsize *= stepsize_decrease_factor diff --git a/src/humancompatible/train/dual_optim/nonopt/optimizer.py b/src/humancompatible/train/dual_optim/nonopt/optimizer.py new file mode 100644 index 0000000..7867f2b --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/optimizer.py @@ -0,0 +1,442 @@ +import math +from functools import reduce + +import torch +from torch.optim import Optimizer + +from .direction import CuttingPlane, DirectionResult, GradientCombination, GradientDirection +from .inverse_hessian import DenseInverseHessian, LimitedMemoryInverseHessian +from .line_search import backtracking, weak_wolfe +from .point_set import Point, PointSet + + +class NonOpt(Optimizer): + def __init__( + self, + params, + direction: str = "cutting_plane", + line_search: str = "weak_wolfe", + inverse_hessian: str = "limited_memory", + history_size: int = 20, + *, + stationarity_radius_initialization_factor: float = 1e-01, + stationarity_radius_initialization_minimum: float = 1e-02, + stationarity_radius_update_factor: float = 1e-01, + stationarity_tolerance: float = 1e-04, + stationarity_tolerance_factor: float = 1e+00, + objective_similarity_tolerance: float = 1e-05, + objective_similarity_limit: int = 10, + iterate_norm_tolerance: float = 1e+20, + point_set_options: dict = None, + direction_options: dict = None, + line_search_options: dict = None, + inverse_hessian_options: dict = None, + ) -> None: + defaults = dict( + direction=direction, + line_search=line_search, + inverse_hessian=inverse_hessian, + history_size=history_size, + stationarity_radius_initialization_factor=stationarity_radius_initialization_factor, + stationarity_radius_initialization_minimum=stationarity_radius_initialization_minimum, + stationarity_radius_update_factor=stationarity_radius_update_factor, + stationarity_tolerance=stationarity_tolerance, + stationarity_tolerance_factor=stationarity_tolerance_factor, + objective_similarity_tolerance=objective_similarity_tolerance, + objective_similarity_limit=objective_similarity_limit, + iterate_norm_tolerance=iterate_norm_tolerance, + ) + super().__init__(params, defaults) + + if len(self.param_groups) != 1: + raise ValueError( + "NonOpt doesn't support per-parameter options (parameter groups)" + ) + self._params = self.param_groups[0]["params"] + + if direction == "cutting_plane": + self._direction = CuttingPlane(**(direction_options or {})) + elif direction == "gradient_combination": + self._direction = GradientCombination(**(direction_options or {})) + elif direction == "gradient": + self._direction = GradientDirection(**(direction_options or {})) + else: + raise ValueError(f"Unknown direction computation strategy: {direction}!") + + if line_search == "weak_wolfe": + self._line_search = weak_wolfe + elif line_search == "backtracking": + self._line_search = backtracking + else: + raise ValueError(f"Unknown line search strategy: {line_search}!") + self._line_search_options = line_search_options or {} + + if inverse_hessian == "limited_memory": + self._inverse_hessian = LimitedMemoryInverseHessian( + history_size=history_size, **(inverse_hessian_options or {}) + ) + elif inverse_hessian == "dense": + self._inverse_hessian = DenseInverseHessian( + **(inverse_hessian_options or {}) + ) + else: + raise ValueError( + f"Unknown inverse Hessian approximation: {inverse_hessian}!" + ) + + self._point_set = PointSet(**(point_set_options or {})) + self._numel_cache = None + + # -- flat parameter handling (as in torch.optim.LBFGS) ---------------------- + + def _numel(self): + if self._numel_cache is None: + self._numel_cache = reduce( + lambda total, p: total + p.numel(), self._params, 0 + ) + return self._numel_cache + + def _gather_flat_params(self): + views = [] + for p in self._params: + if p.is_sparse: + view = p.to_dense().view(-1) + else: + view = p.view(-1) + views.append(view) + return torch.cat(views, dim=0).detach().clone() + + def _gather_flat_grad(self): + views = [] + for p in self._params: + if p.grad is None: + view = p.new_zeros(p.numel()) + elif p.grad.is_sparse: + view = p.grad.to_dense().view(-1) + else: + view = p.grad.view(-1) + views.append(view) + return torch.cat(views, dim=0).detach().clone() + + def _set_flat_params(self, x): + offset = 0 + for p in self._params: + numel = p.numel() + p.copy_(x[offset : offset + numel].view_as(p)) + offset += numel + + # -- termination strategy (port of NonOptTerminationBasic) ------------------ + + def _radii_update_check(self, state, result: DirectionResult) -> bool: + group = self.param_groups[0] + threshold = ( + state["stationarity_radius"] + * group["stationarity_tolerance_factor"] + * state["stationarity_reference_current"] + ) + return ( + result.direction_norm_inf <= threshold + and result.combination_norm_inf <= threshold + ) + + def _check_termination(self, state, result: DirectionResult, f: float) -> None: + group = self.param_groups[0] + tolerance = group["stationarity_tolerance"] + reference = state["stationarity_reference_current"] + + if ( + state["stationarity_radius"] <= tolerance + and result.combination_norm_inf + <= tolerance * group["stationarity_tolerance_factor"] * reference + ): + state["status"] = "stationary" + + if f - state["objective_reference"] >= -group[ + "objective_similarity_tolerance" + ] * max(1.0, abs(state["objective_reference"])): + state["objective_similarity_counter"] += 1 + else: + state["objective_similarity_counter"] = max( + 0, state["objective_similarity_counter"] - 1 + ) + state["objective_reference"] = f + + similarity_exceeded = ( + state["objective_similarity_counter"] > group["objective_similarity_limit"] + ) + if state["stationarity_radius"] <= tolerance and similarity_exceeded: + state["status"] = "objective_similarity" + + if state["stationarity_radius"] > tolerance and ( + result.radii_update_triggered + or self._radii_update_check(state, result) + or similarity_exceeded + ): + state["objective_similarity_counter"] = 0 + state["update_radii"] = True + + # -- main step --------------------------------------------------------------- + + @torch.no_grad() + def step(self, closure): + """ + Performs a single NonOpt (outer) iteration: direction computation, + termination/radii check, line search, inverse Hessian update and point + set update. + + :param closure: A closure that reevaluates the model, calls + ``backward()``, and returns the loss. + :type closure: Callable + :return: Loss at the iterate the step started from. + :rtype: torch.Tensor + """ + closure = torch.enable_grad()(closure) + group = self.param_groups[0] + state = self.state[self._params[0]] + + def evaluate(x): + self._set_flat_params(x) + with torch.enable_grad(): + loss = closure() + return float(loss), self._gather_flat_grad() + + # lazy initialization on the first step + if "x" not in state: + state["x"] = self._gather_flat_params() + loss = closure() + state["f"] = float(loss) + state["g"] = self._gather_flat_grad() + gradient_norm_inf = float(state["g"].abs().max()) + state["stationarity_radius"] = max( + group["stationarity_radius_initialization_minimum"], + group["stationarity_radius_initialization_factor"] * gradient_norm_inf, + ) + state["stationarity_reference"] = gradient_norm_inf + state["objective_reference"] = state["f"] + state["objective_similarity_counter"] = 0 + state["stepsize"] = 1.0 + state["iterate_norm_initial"] = float(state["x"].norm()) + state["n_iterations"] = 0 + state["status"] = "running" + else: + loss = None + + if state["status"] != "running": + self._set_flat_params(state["x"]) + return loss + + x, f, g = state["x"], state["f"], state["g"] + state["update_radii"] = False + state["stationarity_reference_current"] = max( + 1.0, state["stationarity_reference"], float(g.abs().max()) + ) + + # divergence check + if float(x.norm()) >= group["iterate_norm_tolerance"] * max( + 1.0, state["iterate_norm_initial"] + ): + state["status"] = "diverged" + self._set_flat_params(x) + return loss + + # direction computation + result = self._direction.compute( + evaluate, + x, + f, + g, + self._point_set, + self._inverse_hessian, + state["stationarity_radius"], + lambda res: self._radii_update_check(state, res), + ) + + # termination and radii update checks + self._check_termination(state, result, f) + if state["status"] != "running": + self._set_flat_params(x) + return loss + if state["update_radii"]: + state["stationarity_radius"] = max( + group["stationarity_tolerance"], + group["stationarity_radius_update_factor"] + * state["stationarity_radius"], + ) + state["stepsize"] = 1.0 + + # line search + search = self._line_search( + evaluate, + x, + f, + g, + result.direction, + state["stepsize"], + result.decrease_reference, + **self._line_search_options, + ) + + # inverse Hessian update (with self-correction; may be skipped) + self._inverse_hessian.update(search.x - x, search.g - g) + + # accept the iterate and update the point set + if search.stepsize > 0.0: + self._point_set.add(Point(x, f, g)) + state["x"], state["f"], state["g"] = ( + search.x.detach().clone(), + search.f, + search.g.detach().clone(), + ) + state["stepsize"] = search.stepsize if search.stepsize > 0.0 else state["stepsize"] + state["n_iterations"] += 1 + self._point_set.update(state["x"], state["stationarity_radius"]) + + self._set_flat_params(state["x"]) + return loss + + # -- inspection helpers ------------------------------------------------------- + + @property + def status(self) -> str: + """ + Solver status: ``"running"`` (or not yet started), ``"stationary"`` + (stationarity established within the requested tolerance), + ``"objective_similarity"`` (insufficient objective improvement at the + final stationarity radius), or ``"diverged"``. + + :return: Status string. + :rtype: str + """ + state = self.state[self._params[0]] + return state.get("status", "running") + + @property + def converged(self) -> bool: + """ + Whether the solver has terminated successfully (by stationarity or by + objective similarity). + + :return: True if the solver has converged. + :rtype: bool + """ + return self.status in ("stationary", "objective_similarity") + + @property + def stationarity_radius(self) -> float: + """ + Current stationarity radius (``math.inf`` before the first step). + + :return: Stationarity radius. + :rtype: float + """ + state = self.state[self._params[0]] + return state.get("stationarity_radius", math.inf) + + +NonOpt.__doc__ = r""" + A PyTorch port of NonOpt (https://frankecurtis.github.io/NonOpt/), an + open-source solver for unconstrained minimization of locally Lipschitz — + possibly nonconvex and nonsmooth — objective functions, by Frank E. Curtis + and collaborators. Reference: Curtis & Zebiane, + https://doi.org/10.48550/arXiv.2503.22826 + + The method combines quasi-Newton (self-correcting BFGS) Hessian + approximations with proximal-bundle (cutting-plane) or gradient-sampling + direction computations and an (inexact) weak Wolfe line search. At each + iteration a convex quadratic subproblem over the convex hull of recently + observed (sub)gradients is solved to obtain the search direction; a + stationarity radius is shrunk adaptively until approximate stationarity is + established. + + The optimizer follows the interface of :class:`torch.optim.LBFGS`: it + requires a closure that re-evaluates the loss and its gradient, and it may + call the closure several times per step (for bundle and line search + evaluations). Like LBFGS, it works with a single parameter group and is + intended for *deterministic* (full-batch) objectives: + + .. code-block:: python + + optimizer = NonOpt(model.parameters()) + + def closure(): + optimizer.zero_grad() + loss = loss_fn(model(input), target) + loss.backward() + return loss + + for _ in range(max_iterations): + optimizer.step(closure) + if optimizer.converged: + break + + .. note:: + + The objective only needs to be differentiable almost everywhere + (``loss.backward()`` must produce *a* subgradient, which autograd + does for the usual nonsmooth primitives such as ``abs``, ``max`` or + ``relu``). Deviations from the C++ reference: no objective scaling is + applied, the quadratic subproblem is solved by an accelerated + projected-gradient method without a trust-region constraint (inactive + at its default value in the C++ implementation), and subgradient + aggregation is not implemented. + + :param params: Iterable of parameters to optimize. + :type params: iterable + :param direction: Direction computation strategy; one of ``cutting_plane`` + (proximal bundle; default, as in NonOpt), ``gradient_combination`` + (gradient sampling), ``gradient`` (plain quasi-Newton). + :type direction: str + :param line_search: Line search strategy; one of ``weak_wolfe`` (default), + ``backtracking``. + :type line_search: str + :param inverse_hessian: Inverse Hessian approximation; one of + ``limited_memory`` (L-BFGS-style two-loop recursion; default), + ``dense`` (explicit matrix, supports BFGS and DFP updates; only for + small problems). + :type inverse_hessian: str + :param history_size: Number of curvature pairs kept by the limited-memory + approximation. + :type history_size: int + :param stationarity_radius_initialization_factor: Factor for initializing + the stationarity radius: the initial radius is the maximum of this value + times the inf-norm of the initial gradient and + `stationarity_radius_initialization_minimum`. + :type stationarity_radius_initialization_factor: float + :param stationarity_radius_initialization_minimum: Minimum initial value of + the stationarity radius. + :type stationarity_radius_initialization_minimum: float + :param stationarity_radius_update_factor: Factor by which the stationarity + radius is multiplied when the radii-update conditions are met. + :type stationarity_radius_update_factor: float + :param stationarity_tolerance: Tolerance for declaring stationarity; the + algorithm reports convergence once the stationarity radius reaches this + value and the minimum-norm gradient combination is small. + :type stationarity_tolerance: float + :param stationarity_tolerance_factor: Factor applied to the stationarity + tolerance/radius in the termination and radii-update tests. + :type stationarity_tolerance_factor: float + :param objective_similarity_tolerance: If consecutive objective values agree + to within this relative tolerance, a counter is increased; reaching + `objective_similarity_limit` triggers a radius decrease or termination. + :type objective_similarity_tolerance: float + :param objective_similarity_limit: Limit for the objective similarity + counter. + :type objective_similarity_limit: int + :param iterate_norm_tolerance: Divergence is declared when the iterate norm + exceeds this value times ``max(1, ||x_0||)``. + :type iterate_norm_tolerance: float + :param point_set_options: Keyword arguments forwarded to + :class:`~humancompatible.train.dual_optim.nonopt.point_set.PointSet`. + :type point_set_options: dict + :param direction_options: Keyword arguments forwarded to the direction + computation strategy (see + :mod:`~humancompatible.train.dual_optim.nonopt.direction`). + :type direction_options: dict + :param line_search_options: Keyword arguments forwarded to the line search + (see :mod:`~humancompatible.train.dual_optim.nonopt.line_search`). + :type line_search_options: dict + :param inverse_hessian_options: Keyword arguments forwarded to the inverse + Hessian approximation (see + :mod:`~humancompatible.train.dual_optim.nonopt.inverse_hessian`). + :type inverse_hessian_options: dict + """ diff --git a/src/humancompatible/train/dual_optim/nonopt/point_set.py b/src/humancompatible/train/dual_optim/nonopt/point_set.py new file mode 100644 index 0000000..2289cd5 --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/point_set.py @@ -0,0 +1,79 @@ +""" +Point set storage and proximity-based update, a port of +``NonOptPointSetUpdateProximity`` from https://github.com/frankecurtis/NonOpt. +""" + +import math +from typing import NamedTuple + +import torch +from torch import Tensor + + +class Point(NamedTuple): + """A previously visited point: iterate, objective value and (sub)gradient.""" + + x: Tensor + f: float + g: Tensor + + +class PointSet: + """ + Set of previously visited points whose gradients may enter the cutting-plane + bundle. Pruned by age and by proximity to the current iterate. + + :param size_factor: If the size of the point set exceeds this factor times + the number of variables, the oldest members are removed. + :type size_factor: float + :param size_maximum: Hard cap on the size of the point set. The C++ default + is infinity; here it defaults to 100 to bound memory, since each point + stores a full gradient copy. + :type size_maximum: int + :param envelope_factor: A point is removed when its distance to the current + iterate exceeds this factor times the stationarity radius. + :type envelope_factor: float + """ + + def __init__( + self, + size_factor: float = 5e-02, + size_maximum: int = 100, + envelope_factor: float = 1e+02, + ) -> None: + self.size_factor = size_factor + self.size_maximum = size_maximum if size_maximum is not None else math.inf + self.envelope_factor = envelope_factor + self.points: list[Point] = [] + + def __len__(self) -> int: + return len(self.points) + + def __iter__(self): + return iter(self.points) + + def add(self, point: Point) -> None: + """Appends a point (its tensors are stored as detached clones).""" + self.points.append( + Point(point.x.detach().clone(), float(point.f), point.g.detach().clone()) + ) + + def update(self, x_current: Tensor, stationarity_radius: float) -> None: + """ + Prunes the point set: removes the oldest members while the set is too + large, then removes all points farther than + ``envelope_factor * stationarity_radius`` from the current iterate. + + :param x_current: Current iterate (flat). + :type x_current: torch.Tensor + :param stationarity_radius: Current stationarity radius. + :type stationarity_radius: float + """ + n = x_current.numel() + limit = min(self.size_factor * n, self.size_maximum) + if len(self.points) > limit: + del self.points[: len(self.points) - max(int(limit), 0)] + radius = self.envelope_factor * stationarity_radius + self.points = [ + p for p in self.points if float(torch.norm(x_current - p.x)) <= radius + ] diff --git a/src/humancompatible/train/dual_optim/nonopt/qp.py b/src/humancompatible/train/dual_optim/nonopt/qp.py new file mode 100644 index 0000000..ade8c67 --- /dev/null +++ b/src/humancompatible/train/dual_optim/nonopt/qp.py @@ -0,0 +1,111 @@ +""" +Solver for the quadratic subproblem arising in the NonOpt direction computations. + +The subproblem is the dual of the (proximal) cutting-plane / gradient-combination +subproblem: + +.. math:: + \\min_{\\omega \\in \\Delta^k} \\; + \\tfrac{1}{2} \\omega^T (G^T W G) \\omega - b^T \\omega, + +where :math:`\\Delta^k` is the unit simplex, the columns of :math:`G` are the +(sub)gradients in the bundle, :math:`W` is an inverse Hessian approximation and +:math:`b` collects the linear (cutting-plane) terms. The primal search direction +is recovered as :math:`d = -W G \\omega`. + +The reference C++ implementation (https://github.com/frankecurtis/NonOpt) solves +this with a specialized dual active-set method; since the subproblem dimension +equals the bundle size (small), an accelerated projected-gradient method with an +exact simplex projection is used here instead. +""" + +import torch +from torch import Tensor + + +def project_onto_simplex(v: Tensor) -> Tensor: + """ + Computes the Euclidean projection of a vector onto the unit simplex + :math:`\\{\\omega : \\omega \\geq 0, \\sum_i \\omega_i = 1\\}`. + + :param v: Vector to project. + :type v: torch.Tensor + :return: Projection of `v` onto the unit simplex. + :rtype: torch.Tensor + """ + k = v.numel() + u, _ = torch.sort(v, descending=True) + cumulative = torch.cumsum(u, dim=0) - 1.0 + indices = torch.arange(1, k + 1, dtype=v.dtype, device=v.device) + positive = u - cumulative / indices > 0 + if not positive.any(): # only possible with non-finite input + return torch.full_like(v, 1.0 / k) + rho = int(torch.nonzero(positive)[-1].item()) + theta = cumulative[rho] / (rho + 1.0) + return torch.clamp(v - theta, min=0.0) + + +def solve_simplex_qp( + Q: Tensor, + b: Tensor, + tol: float = 1e-10, + max_iterations: int = None, +) -> Tensor: + """ + Solves :math:`\\min_{\\omega \\in \\Delta^k} \\tfrac{1}{2}\\omega^T Q \\omega - b^T \\omega` + over the unit simplex with an accelerated projected-gradient (FISTA) method. + + :param Q: Symmetric positive semi-definite matrix of shape ``(k, k)``. + :type Q: torch.Tensor + :param b: Linear term of shape ``(k,)``. + :type b: torch.Tensor + :param tol: Tolerance on the projected-gradient KKT residual. + :type tol: float + :param max_iterations: Iteration limit; defaults to ``max(200, 20 * k)``. + :type max_iterations: int + :return: Approximate solution ``omega`` of shape ``(k,)``. + :rtype: torch.Tensor + """ + k = b.numel() + if k == 1: + return torch.ones_like(b) + + Q = 0.5 * (Q + Q.t()) + + # Lipschitz constant of the gradient; k is small, so an exact eigenvalue is cheap + try: + lipschitz = float(torch.linalg.eigvalsh(Q)[-1]) + except Exception: + lipschitz = float(Q.abs().sum(dim=1).max()) + + if not lipschitz > tol: + # negligible quadratic term: the linear program is solved at a vertex + omega = torch.zeros_like(b) + omega[int(torch.argmax(b))] = 1.0 + return omega + + if max_iterations is None: + max_iterations = max(200, 20 * k) + + omega = torch.full_like(b, 1.0 / k) + accelerated = omega.clone() + momentum = 1.0 + for _ in range(max_iterations): + gradient = Q @ omega - b + kkt_residual = (omega - project_onto_simplex(omega - gradient)).abs().max() + if kkt_residual <= tol: + break + omega_new = project_onto_simplex( + accelerated - (Q @ accelerated - b) / lipschitz + ) + momentum_new = 0.5 * (1.0 + (1.0 + 4.0 * momentum**2) ** 0.5) + accelerated = omega_new + ((momentum - 1.0) / momentum_new) * ( + omega_new - omega + ) + # restart acceleration if it points uphill + if torch.dot(omega_new - omega, gradient) > 0: + accelerated = omega_new.clone() + momentum_new = 1.0 + omega, momentum = omega_new, momentum_new + + return omega diff --git a/tests/test_nonopt.py b/tests/test_nonopt.py new file mode 100644 index 0000000..369aa72 --- /dev/null +++ b/tests/test_nonopt.py @@ -0,0 +1,265 @@ +import math +import unittest + +import torch +from humancompatible.train.dual_optim import NonOpt +from humancompatible.train.dual_optim.nonopt import ( + LimitedMemoryInverseHessian, + project_onto_simplex, + solve_simplex_qp, +) + + +def run_nonopt(x0, objective, max_iterations=500, **options): + """Runs NonOpt on a tensor objective and returns (optimizer, x, f_final).""" + x = torch.nn.Parameter(x0.clone()) + optimizer = NonOpt([x], **options) + + def closure(): + optimizer.zero_grad() + loss = objective(x) + loss.backward() + return loss + + for _ in range(max_iterations): + optimizer.step(closure) + if optimizer.converged: + break + return optimizer, x.detach(), float(objective(x)) + + +def maxq(x): + """MaxQ: f(x) = max_i x_i^2, nonsmooth convex, f* = 0.""" + return (x**2).max() + + +def maxq_x0(n): + x0 = torch.arange(1.0, n + 1.0) + x0[n // 2 :] *= -1.0 + return x0 + + +def chained_lq(x): + """Chained LQ, nonsmooth convex, f* = -(n-1) * sqrt(2).""" + a = -x[:-1] - x[1:] + b = a + (x[:-1] ** 2 + x[1:] ** 2 - 1.0) + return torch.maximum(a, b).sum() + + +class TestSimplexQP(unittest.TestCase): + """Test the simplex projection and the QP subproblem solver.""" + + def test_projection_already_feasible(self): + v = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float64) + self.assertTrue(torch.allclose(project_onto_simplex(v), v)) + + def test_projection_sums_to_one_and_nonnegative(self): + torch.manual_seed(0) + for _ in range(10): + v = torch.randn(7, dtype=torch.float64) * 5 + p = project_onto_simplex(v) + self.assertAlmostEqual(float(p.sum()), 1.0, places=10) + self.assertTrue((p >= 0).all()) + + def test_projection_single_dominant_coordinate(self): + v = torch.tensor([10.0, 0.0, 0.0], dtype=torch.float64) + p = project_onto_simplex(v) + self.assertTrue(torch.allclose(p, torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64))) + + def test_qp_identity_quadratic(self): + # min ½||ω||² over simplex -> uniform + Q = torch.eye(4, dtype=torch.float64) + b = torch.zeros(4, dtype=torch.float64) + omega = solve_simplex_qp(Q, b) + self.assertTrue(torch.allclose(omega, torch.full((4,), 0.25, dtype=torch.float64), atol=1e-6)) + + def test_qp_single_element(self): + Q = torch.tensor([[2.0]], dtype=torch.float64) + b = torch.tensor([1.0], dtype=torch.float64) + omega = solve_simplex_qp(Q, b) + self.assertTrue(torch.allclose(omega, torch.ones(1, dtype=torch.float64))) + + def test_qp_known_solution(self): + # min ½ ωᵀ diag(1, 4) ω over the simplex: ω = (4/5, 1/5) + Q = torch.diag(torch.tensor([1.0, 4.0], dtype=torch.float64)) + b = torch.zeros(2, dtype=torch.float64) + omega = solve_simplex_qp(Q, b) + self.assertTrue( + torch.allclose(omega, torch.tensor([0.8, 0.2], dtype=torch.float64), atol=1e-6) + ) + + def test_qp_linear_only(self): + # negligible quadratic: solution at the vertex maximizing b + Q = torch.zeros(3, 3, dtype=torch.float64) + b = torch.tensor([-1.0, 3.0, 0.5], dtype=torch.float64) + omega = solve_simplex_qp(Q, b) + self.assertEqual(int(torch.argmax(omega)), 1) + self.assertAlmostEqual(float(omega.sum()), 1.0, places=10) + + +class TestLimitedMemoryInverseHessian(unittest.TestCase): + """Test the self-correcting L-BFGS inverse Hessian approximation.""" + + def test_identity_before_updates(self): + W = LimitedMemoryInverseHessian() + v = torch.randn(5) + self.assertTrue(torch.allclose(W.apply(v), v)) + + def test_secant_equation(self): + # after an update with curvature pair (s, y), W y = s must hold + torch.manual_seed(1) + W = LimitedMemoryInverseHessian() + s = torch.randn(6, dtype=torch.float64) + y = s + 0.1 * torch.randn(6, dtype=torch.float64) + if float(s.dot(y)) <= 0: + y = s.clone() + self.assertTrue(W.update(s, y)) + self.assertTrue(torch.allclose(W.apply(y), s, atol=1e-10)) + + def test_update_skipped_on_tiny_displacement(self): + W = LimitedMemoryInverseHessian() + s = torch.full((4,), 1e-12, dtype=torch.float64) + y = torch.full((4,), 1e-12, dtype=torch.float64) + self.assertFalse(W.update(s, y)) + + def test_self_correction_keeps_curvature_positive(self): + # negative curvature pair must be corrected, not produce a singular update + W = LimitedMemoryInverseHessian() + s = torch.tensor([1.0, 0.0], dtype=torch.float64) + y = torch.tensor([-1.0, 0.5], dtype=torch.float64) # s·y < 0 + self.assertTrue(W.update(s, y)) + v = torch.randn(2, dtype=torch.float64) + Wv = W.apply(v) + # W must remain positive definite + self.assertGreater(float(v.dot(Wv)), 0.0) + + +class TestNonOptOnNonsmoothProblems(unittest.TestCase): + """Test convergence on classic nonsmooth test problems from NonOpt.""" + + def test_maxq_cutting_plane(self): + torch.manual_seed(0) + optimizer, x, f = run_nonopt(maxq_x0(10), maxq) + self.assertLess(f, 1e-04) + + def test_maxq_gradient_combination(self): + torch.manual_seed(0) + optimizer, x, f = run_nonopt( + maxq_x0(10), maxq, direction="gradient_combination" + ) + self.assertLess(f, 1e-04) + + def test_maxq_gradient(self): + torch.manual_seed(0) + optimizer, x, f = run_nonopt(maxq_x0(10), maxq, direction="gradient") + self.assertLess(f, 1e-02) + + def test_maxq_backtracking(self): + torch.manual_seed(0) + optimizer, x, f = run_nonopt(maxq_x0(10), maxq, line_search="backtracking") + self.assertLess(f, 1e-04) + + def test_maxq_dense_bfgs(self): + torch.manual_seed(0) + optimizer, x, f = run_nonopt(maxq_x0(10), maxq, inverse_hessian="dense") + self.assertLess(f, 1e-04) + + def test_maxq_dense_dfp(self): + # DFP is less robust than BFGS on nonsmooth problems (hence NonOpt + # defaults to BFGS); it converges more slowly, so use a looser bar. + torch.manual_seed(0) + optimizer, x, f = run_nonopt( + maxq_x0(10), + maxq, + inverse_hessian="dense", + inverse_hessian_options={"formula": "dfp"}, + ) + self.assertLess(f, 1e-03) + + def test_chained_lq(self): + torch.manual_seed(0) + n = 10 + optimizer, x, f = run_nonopt(-0.5 * torch.ones(n), chained_lq) + f_star = -(n - 1) * math.sqrt(2.0) + self.assertLess(f - f_star, 1e-03 * abs(f_star)) + + def test_l1_regression(self): + # piecewise-linear convex: f(w) = ||A w - b||_1 with known minimizer + torch.manual_seed(2) + A = torch.randn(30, 5) + w_star = torch.randn(5) + b = A @ w_star + + optimizer, w, f = run_nonopt(torch.zeros(5), lambda w: (A @ w - b).abs().sum()) + self.assertLess(f, 1e-03) + self.assertTrue(torch.allclose(w, w_star, atol=1e-03)) + + def test_converged_flag_and_status(self): + torch.manual_seed(0) + optimizer, x, f = run_nonopt(maxq_x0(10), maxq, max_iterations=1000) + self.assertTrue(optimizer.converged) + self.assertIn(optimizer.status, ("stationary", "objective_similarity")) + + def test_smooth_problem(self): + # smooth strongly convex sanity check + optimizer, x, f = run_nonopt( + torch.tensor([3.0, -4.0]), lambda x: ((x - 1.0) ** 2).sum() + ) + self.assertTrue(torch.allclose(x, torch.ones(2), atol=1e-03)) + + +class TestNonOptInterface(unittest.TestCase): + """Test torch.optim.Optimizer interface compliance.""" + + def test_works_with_nn_module(self): + torch.manual_seed(3) + model = torch.nn.Linear(4, 1) + A = torch.randn(20, 4) + b = torch.randn(20, 1) + optimizer = NonOpt(model.parameters()) + + def closure(): + optimizer.zero_grad() + loss = (model(A) - b).abs().mean() + loss.backward() + return loss + + initial = float(closure()) + for _ in range(100): + optimizer.step(closure) + if optimizer.converged: + break + final = float(closure()) + self.assertLess(final, initial) + + def test_rejects_multiple_param_groups(self): + p1 = torch.nn.Parameter(torch.zeros(2)) + p2 = torch.nn.Parameter(torch.zeros(2)) + with self.assertRaises(ValueError): + NonOpt([{"params": [p1]}, {"params": [p2], "history_size": 5}]) + + def test_rejects_unknown_strategies(self): + p = torch.nn.Parameter(torch.zeros(2)) + with self.assertRaises(ValueError): + NonOpt([p], direction="unknown") + with self.assertRaises(ValueError): + NonOpt([p], line_search="unknown") + with self.assertRaises(ValueError): + NonOpt([p], inverse_hessian="unknown") + + def test_step_returns_initial_loss(self): + p = torch.nn.Parameter(torch.tensor([2.0])) + optimizer = NonOpt([p]) + + def closure(): + optimizer.zero_grad() + loss = (p**2).sum() + loss.backward() + return loss + + loss = optimizer.step(closure) + self.assertAlmostEqual(float(loss), 4.0, places=6) + + +if __name__ == "__main__": + unittest.main()