Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions src/flowMC/strategy/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ class AdamOptimization(Strategy):
Learning rate for the optimization.
noise_level: float = 10
Variance of the noise added to the gradients.
bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
Bounds for the optimization. The optimization will be projected to these bounds.
If bounds has shape (1, 2), it will be broadcast to all dimensions. For n_dim > 1,
passing a (1, 2) array applies the same bound to every dimension. To specify different
bounds per dimension, provide an array of shape (n_dim, 2).
"""

logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
n_steps: int = 100
learning_rate: float = 1e-2
noise_level: float = 10
bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

def __repr__(self):
return "AdamOptimization"
Expand All @@ -37,14 +42,22 @@ def __init__(
n_steps: int = 100,
learning_rate: float = 1e-2,
noise_level: float = 10,
bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
):
self.logpdf = logpdf
self.n_steps = n_steps
self.learning_rate = learning_rate
self.noise_level = noise_level
self.bounds = bounds

# Validate bounds shape
if bounds.ndim != 2 or bounds.shape[1] != 2:
raise ValueError(
f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}"
)
# If bounds is (1, 2), it will be broadcast to all dimensions. If not, check compatibility.
# Try to infer n_dim from logpdf signature or initial_position, but here we can't, so warn in runtime.

self.solver = optax.chain(
optax.adam(learning_rate=self.learning_rate),
)
Expand All @@ -58,12 +71,12 @@ def __call__(
) -> tuple[
PRNGKeyArray,
dict[str, Resource],
Float[Array, "n_chains n_dim"],
Float[Array, " n_chain n_dim"],
]:
Comment thread
thomasckng marked this conversation as resolved.
def loss_fn(params: Float[Array, " n_dim"]) -> Float:
def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float:

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kazewong This may affect performance, but removing it will cause an issue when calling below.

return -self.logpdf(params, data)

rng_key, optimized_positions = self.optimize(
rng_key, optimized_positions, _ = self.optimize(
rng_key, loss_fn, initial_position, data
)

Expand All @@ -76,6 +89,14 @@ def optimize(
initial_position: Float[Array, " n_chain n_dim"],
data: dict,
):
# Validate bounds shape against n_dim
n_dim = initial_position.shape[-1]
if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
raise ValueError(
f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
"Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
)

"""Optimization kernel. This can be used independently of the __call__ method.

Args:
Expand All @@ -85,14 +106,26 @@ def optimize(
Objective function to optimize.
initial_position: Float[Array, " n_chain n_dim"]
Initial positions for the optimization.
data: dict
Data to pass to the objective function.

Returns:
rng_key: PRNGKeyArray
Updated random key.
optimized_positions: Float[Array, " n_chain n_dim"]
Optimized positions.
final_log_prob: Float[Array, " n_chain"]
Final log-probabilities of the optimized positions.
"""
grad_fn = jax.jit(jax.grad(objective))

def _kernel(carry, data):
def _kernel(carry, _step):
key, params, opt_state = carry

key, subkey = jax.random.split(key)
grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
grad = grad_fn(params, data) * (
1 + jax.random.normal(subkey) * self.noise_level
)
updates, opt_state = self.solver.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)
params = optax.projections.projection_box(
Expand Down Expand Up @@ -128,4 +161,4 @@ def _single_optimize(
if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any():
print("Warning: Optimization accessed infinite or NaN log-probabilities.")

return rng_key, optimized_positions
return rng_key, optimized_positions, final_log_prob
7 changes: 5 additions & 2 deletions test/unit/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def test_standalone_optimize(self):
jax.random.normal(subkey, shape=(self.n_chains, self.n_dim)) * 1 + 10
)

def loss_fn(params: Float[Array, " n_dim"]) -> Float:
def loss_fn(params: Float[Array, " n_dim"], data: dict = {}) -> Float:
return -log_posterior(params, {"data": jnp.arange(self.n_dim)})

rng_key, optimized_position = self.strategy.optimize(
_, optimized_position, final_log_prob = self.strategy.optimize(
key, loss_fn, initial_position, {"data": jnp.arange(self.n_dim)}
)

Expand All @@ -73,6 +73,9 @@ def loss_fn(params: Float[Array, " n_dim"]) -> Float:
jnp.mean(optimized_position, axis=1) < jnp.mean(initial_position, axis=1)
)

assert final_log_prob.shape == (self.n_chains,)
assert jnp.all(jnp.isfinite(final_log_prob))


class TestLocalStep:
def test_take_local_step(self):
Expand Down
Loading