Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/aspire/samplers/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@

logger = logging.getLogger(__name__)

DEFAULT_BETA_TOLERANCE = 1e-8


class BetaScheduleError(RuntimeError):
pass


class SMCSampler(MCMCSampler):
"""Base class for Sequential Monte Carlo samplers.
Expand Down Expand Up @@ -121,7 +127,7 @@ def determine_beta(
beta_step: float,
min_beta_step: float,
max_beta_step: float = 1.0,
beta_tolerance: float = 1e-6,
beta_tolerance: float = DEFAULT_BETA_TOLERANCE,
) -> tuple[float, float]:
"""Determine the next beta value.

Expand All @@ -146,6 +152,12 @@ def determine_beta(
The new beta value.
min_beta_step : float
The new minimum beta step size if adaptive_min_beta_step is True.

Raises
------
BetaScheduleError
If adaptive beta is enabled and the determined beta does not increase
from the previous beta.
"""
if not self.adaptive:
beta += beta_step
Expand All @@ -158,9 +170,10 @@ def determine_beta(
eff_beta_max = effective_sample_size(
samples.log_weights(beta_max)
) / len(samples)
if eff_beta_max >= self.current_target_efficiency(beta_prev):
current_eff = self.current_target_efficiency(beta_prev)
if eff_beta_max >= current_eff:
beta_min = 1.0
target_eff = self.current_target_efficiency(beta_prev)
target_eff = current_eff
while beta_max - beta_min > beta_tolerance:
beta_try = 0.5 * (beta_max + beta_min)
eff = effective_sample_size(
Expand All @@ -171,13 +184,32 @@ def determine_beta(
else:
beta_max = beta_try
beta_star = beta_min
if beta_star <= beta_prev + beta_tolerance and beta_prev < 1.0:
logger.warning(
"Adaptive beta search could not find a beta above %.6g "
"that satisfies the target efficiency %.3f within "
"tolerance %.1e; beta may remain unchanged. "
"Consider decreasing beta_tolerance or target_efficiency.",
beta_prev,
target_eff,
beta_tolerance,
)

if self.adaptive_min_beta_step:
min_beta_step = (
min_beta_step * (1 - beta_prev) / (1 - beta_star)
)
beta = max(beta_star, beta_prev + min_beta_step)
beta = min(beta, beta_prev + max_beta_step, 1.0)
if beta == beta_prev:
raise BetaScheduleError(
f"Beta did not increase from previous value {beta:.6g}. "
"Adaptive beta search may have failed to find a suitable beta. "
f"Consider adjusting beta_tolerance ({beta_tolerance}), "
f"min_beta_step ({min_beta_step}) or "
f"target_efficiency ({target_eff}) "
"(values may be adaptive)."
)
return beta, min_beta_step

@track_calls
Expand All @@ -197,7 +229,7 @@ def sample(
checkpoint_file_path: str | None = None,
resume_from: str | bytes | dict | None = None,
store_sample_history: bool = True,
beta_tolerance: float = 1e-6,
beta_tolerance: float = DEFAULT_BETA_TOLERANCE,
) -> SMCSamples:
"""Sample using the SMC sampler.

Expand Down Expand Up @@ -250,7 +282,7 @@ def sample(
:code:`self.history.sample_history`. Default is True.
beta_tolerance : float, optional
Tolerance for determining convergence of beta when using adaptive
beta. Default is 1e-6.
beta. Default is given by :code:`DEFAULT_BETA_TOLERANCE`.

Returns
-------
Expand Down
Loading