diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 7d5f6997..9f9719f4 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -20,13 +20,18 @@ class AdamOptimization(Strategy): Learning rate for the optimization. noise_level: float = 10 Variance of the noise added to the gradients. + bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) + Bounds for the optimization. The optimization will be projected to these bounds. + If bounds has shape (1, 2), it will be broadcast to all dimensions. For n_dim > 1, + passing a (1, 2) array applies the same bound to every dimension. To specify different + bounds per dimension, provide an array of shape (n_dim, 2). """ logpdf: Callable[[Float[Array, " n_dim"], dict], Float] n_steps: int = 100 learning_rate: float = 1e-2 noise_level: float = 10 - bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) + bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) def __repr__(self): return "AdamOptimization" @@ -37,7 +42,7 @@ def __init__( n_steps: int = 100, learning_rate: float = 1e-2, noise_level: float = 10, - bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]), + bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]), ): self.logpdf = logpdf self.n_steps = n_steps @@ -45,6 +50,14 @@ def __init__( self.noise_level = noise_level self.bounds = bounds + # Validate bounds shape + if bounds.ndim != 2 or bounds.shape[1] != 2: + raise ValueError( + f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}" + ) + # If bounds is (1, 2), it will be broadcast to all dimensions. If not, check compatibility. + # Try to infer n_dim from logpdf signature or initial_position, but here we can't, so warn in runtime. + self.solver = optax.chain( optax.adam(learning_rate=self.learning_rate), ) @@ -58,12 +71,12 @@ def __call__( ) -> tuple[ PRNGKeyArray, dict[str, Resource], - Float[Array, "n_chains n_dim"], + Float[Array, " n_chain n_dim"], ]: - def loss_fn(params: Float[Array, " n_dim"]) -> Float: + def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float: return -self.logpdf(params, data) - rng_key, optimized_positions = self.optimize( + rng_key, optimized_positions, _ = self.optimize( rng_key, loss_fn, initial_position, data ) @@ -76,6 +89,14 @@ def optimize( initial_position: Float[Array, " n_chain n_dim"], data: dict, ): + # Validate bounds shape against n_dim + n_dim = initial_position.shape[-1] + if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim): + raise ValueError( + f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. " + "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds." + ) + """Optimization kernel. This can be used independently of the __call__ method. Args: @@ -85,14 +106,26 @@ def optimize( Objective function to optimize. initial_position: Float[Array, " n_chain n_dim"] Initial positions for the optimization. + data: dict + Data to pass to the objective function. + + Returns: + rng_key: PRNGKeyArray + Updated random key. + optimized_positions: Float[Array, " n_chain n_dim"] + Optimized positions. + final_log_prob: Float[Array, " n_chain"] + Final log-probabilities of the optimized positions. """ grad_fn = jax.jit(jax.grad(objective)) - def _kernel(carry, data): + def _kernel(carry, _step): key, params, opt_state = carry key, subkey = jax.random.split(key) - grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level) + grad = grad_fn(params, data) * ( + 1 + jax.random.normal(subkey) * self.noise_level + ) updates, opt_state = self.solver.update(grad, opt_state, params) params = optax.apply_updates(params, updates) params = optax.projections.projection_box( @@ -128,4 +161,4 @@ def _single_optimize( if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any(): print("Warning: Optimization accessed infinite or NaN log-probabilities.") - return rng_key, optimized_positions + return rng_key, optimized_positions, final_log_prob diff --git a/test/unit/test_strategies.py b/test/unit/test_strategies.py index 2bcbca8b..f842fa51 100644 --- a/test/unit/test_strategies.py +++ b/test/unit/test_strategies.py @@ -61,10 +61,10 @@ def test_standalone_optimize(self): jax.random.normal(subkey, shape=(self.n_chains, self.n_dim)) * 1 + 10 ) - def loss_fn(params: Float[Array, " n_dim"]) -> Float: + def loss_fn(params: Float[Array, " n_dim"], data: dict = {}) -> Float: return -log_posterior(params, {"data": jnp.arange(self.n_dim)}) - rng_key, optimized_position = self.strategy.optimize( + _, optimized_position, final_log_prob = self.strategy.optimize( key, loss_fn, initial_position, {"data": jnp.arange(self.n_dim)} ) @@ -73,6 +73,9 @@ def loss_fn(params: Float[Array, " n_dim"]) -> Float: jnp.mean(optimized_position, axis=1) < jnp.mean(initial_position, axis=1) ) + assert final_log_prob.shape == (self.n_chains,) + assert jnp.all(jnp.isfinite(final_log_prob)) + class TestLocalStep: def test_take_local_step(self):