From 0bc1e93d656c44b0b1bde5faf025dbf2509a63db Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 09:42:43 +0800 Subject: [PATCH 01/14] Fix formatting of dimensions in AdamOptimization class --- src/flowMC/strategy/optimization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 7d5f6997..814b2510 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -26,7 +26,7 @@ class AdamOptimization(Strategy): n_steps: int = 100 learning_rate: float = 1e-2 noise_level: float = 10 - bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) + bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) def __repr__(self): return "AdamOptimization" @@ -37,7 +37,7 @@ def __init__( n_steps: int = 100, learning_rate: float = 1e-2, noise_level: float = 10, - bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]), + bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]), ): self.logpdf = logpdf self.n_steps = n_steps @@ -58,7 +58,7 @@ def __call__( ) -> tuple[ PRNGKeyArray, dict[str, Resource], - Float[Array, "n_chains n_dim"], + Float[Array, " n_chains n_dim"], ]: def loss_fn(params: Float[Array, " n_dim"]) -> Float: return -self.logpdf(params, data) From 1cefa8809eeb3bf61664b0e943da9376d513ba55 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 09:42:56 +0800 Subject: [PATCH 02/14] Fix gradient calculation in AdamOptimization to include data in grad_fn --- src/flowMC/strategy/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 814b2510..3c455040 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -92,7 +92,7 @@ def _kernel(carry, data): key, params, opt_state = carry key, subkey = jax.random.split(key) - grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level) + grad = grad_fn(params, data) * (1 + jax.random.normal(subkey) * self.noise_level) updates, opt_state = self.solver.update(grad, opt_state, params) params = optax.apply_updates(params, updates) params = optax.projections.projection_box( From f20998e6e30036362f8031b50449eda618d5c1e5 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 09:43:03 +0800 Subject: [PATCH 03/14] Add final_log_prob return value to AdamOptimization.optimize method --- src/flowMC/strategy/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 3c455040..5a22b8ea 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -128,4 +128,4 @@ def _single_optimize( if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any(): print("Warning: Optimization accessed infinite or NaN log-probabilities.") - return rng_key, optimized_positions + return rng_key, optimized_positions, final_log_prob From 99b6fa7330e10c3ff2975d5560a6fc5d1f30fee7 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 09:44:50 +0800 Subject: [PATCH 04/14] Fix return value unpacking in AdamOptimization.__call__ method --- src/flowMC/strategy/optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 5a22b8ea..d7eb07cd 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -63,7 +63,7 @@ def __call__( def loss_fn(params: Float[Array, " n_dim"]) -> Float: return -self.logpdf(params, data) - rng_key, optimized_positions = self.optimize( + rng_key, optimized_positions, _ = self.optimize( rng_key, loss_fn, initial_position, data ) From 6f15b5ed147cbbb987de751fc8ada1d65eccb64f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:03:40 +0800 Subject: [PATCH 05/14] Formatting --- src/flowMC/strategy/optimization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index d7eb07cd..d385349f 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -92,7 +92,9 @@ def _kernel(carry, data): key, params, opt_state = carry key, subkey = jax.random.split(key) - grad = grad_fn(params, data) * (1 + jax.random.normal(subkey) * self.noise_level) + grad = grad_fn(params, data) * ( + 1 + jax.random.normal(subkey) * self.noise_level + ) updates, opt_state = self.solver.update(grad, opt_state, params) params = optax.apply_updates(params, updates) params = optax.projections.projection_box( From f593ed013a2c154e215adc64d05c73bfa916b284 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:06:23 +0800 Subject: [PATCH 06/14] Fix unpacking of return values in TestOptimizationStrategies.optimize method --- test/unit/test_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_strategies.py b/test/unit/test_strategies.py index 2bcbca8b..f6f2aae6 100644 --- a/test/unit/test_strategies.py +++ b/test/unit/test_strategies.py @@ -64,7 +64,7 @@ def test_standalone_optimize(self): def loss_fn(params: Float[Array, " n_dim"]) -> Float: return -log_posterior(params, {"data": jnp.arange(self.n_dim)}) - rng_key, optimized_position = self.strategy.optimize( + rng_key, optimized_position, _ = self.strategy.optimize( key, loss_fn, initial_position, {"data": jnp.arange(self.n_dim)} ) From 5651d2471c7ce9132ba0ef94d7f587a992614cf5 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:12:20 +0800 Subject: [PATCH 07/14] Add final_log_prob assertion to AdamOptimization test --- test/unit/test_strategies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/unit/test_strategies.py b/test/unit/test_strategies.py index f6f2aae6..e8a9a3b6 100644 --- a/test/unit/test_strategies.py +++ b/test/unit/test_strategies.py @@ -64,7 +64,7 @@ def test_standalone_optimize(self): def loss_fn(params: Float[Array, " n_dim"]) -> Float: return -log_posterior(params, {"data": jnp.arange(self.n_dim)}) - rng_key, optimized_position, _ = self.strategy.optimize( + rng_key, optimized_position, final_log_prob = self.strategy.optimize( key, loss_fn, initial_position, {"data": jnp.arange(self.n_dim)} ) @@ -73,6 +73,9 @@ def loss_fn(params: Float[Array, " n_dim"]) -> Float: jnp.mean(optimized_position, axis=1) < jnp.mean(initial_position, axis=1) ) + assert final_log_prob.shape == (self.n_chains,) + assert jnp.all(jnp.isfinite(final_log_prob)) + class TestLocalStep: def test_take_local_step(self): From 0908327e1a72243de0c36be54dbbb992ba9104d8 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:16:07 +0800 Subject: [PATCH 08/14] Add docstring for optimize method in AdamOptimization class --- src/flowMC/strategy/optimization.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index d385349f..b1634bee 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -85,6 +85,16 @@ def optimize( Objective function to optimize. initial_position: Float[Array, " n_chain n_dim"] Initial positions for the optimization. + data: dict + Data to pass to the objective function. + + Returns: + rng_key: PRNGKeyArray + Updated random key. + optimized_positions: Float[Array, " n_chain n_dim"] + Optimized positions. + final_log_prob: Float[Array, " n_chain"] + Final log-probabilities of the optimized positions. """ grad_fn = jax.jit(jax.grad(objective)) From d0cb0eed4e1148084d6353b527d4fab284715dea Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:16:48 +0800 Subject: [PATCH 09/14] Add bounds argument description to AdamOptimization class docstring --- src/flowMC/strategy/optimization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index b1634bee..c5d79046 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -20,6 +20,8 @@ class AdamOptimization(Strategy): Learning rate for the optimization. noise_level: float = 10 Variance of the noise added to the gradients. + bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) + Bounds for the optimization. The optimization will be projected to these bounds. """ logpdf: Callable[[Float[Array, " n_dim"], dict], Float] From 03267336eb3c35d8aa7a9b3b91991577af07bb05 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:37:19 +0800 Subject: [PATCH 10/14] Update loss_fn signature to include data parameter in optimization strategy tests --- src/flowMC/strategy/optimization.py | 2 +- test/unit/test_strategies.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index c5d79046..fe75c910 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -62,7 +62,7 @@ def __call__( dict[str, Resource], Float[Array, " n_chains n_dim"], ]: - def loss_fn(params: Float[Array, " n_dim"]) -> Float: + def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float: return -self.logpdf(params, data) rng_key, optimized_positions, _ = self.optimize( diff --git a/test/unit/test_strategies.py b/test/unit/test_strategies.py index e8a9a3b6..f842fa51 100644 --- a/test/unit/test_strategies.py +++ b/test/unit/test_strategies.py @@ -61,10 +61,10 @@ def test_standalone_optimize(self): jax.random.normal(subkey, shape=(self.n_chains, self.n_dim)) * 1 + 10 ) - def loss_fn(params: Float[Array, " n_dim"]) -> Float: + def loss_fn(params: Float[Array, " n_dim"], data: dict = {}) -> Float: return -log_posterior(params, {"data": jnp.arange(self.n_dim)}) - rng_key, optimized_position, final_log_prob = self.strategy.optimize( + _, optimized_position, final_log_prob = self.strategy.optimize( key, loss_fn, initial_position, {"data": jnp.arange(self.n_dim)} ) From 332c68721036689f3862c44348408f47a64d836b Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:50:33 +0800 Subject: [PATCH 11/14] Fix gradient computation in optimize method to prevent data shadowing --- src/flowMC/strategy/optimization.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index fe75c910..dc11152d 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -98,13 +98,16 @@ def optimize( final_log_prob: Float[Array, " n_chain"] Final log-probabilities of the optimized positions. """ - grad_fn = jax.jit(jax.grad(objective)) + data_dict = data # prevent shadowing inside the scan + grad_fn = jax.jit( + jax.grad(objective, argnums=0) + ) # differentiate w.r.t. params only - def _kernel(carry, data): + def _kernel(carry, _step): key, params, opt_state = carry key, subkey = jax.random.split(key) - grad = grad_fn(params, data) * ( + grad = grad_fn(params, data_dict) * ( 1 + jax.random.normal(subkey) * self.noise_level ) updates, opt_state = self.solver.update(grad, opt_state, params) From fbe05fbf5eb8a1d4cb062d539e6f3f6619f8278f Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 10:59:57 +0800 Subject: [PATCH 12/14] Revert some suggestions by coderabbit --- src/flowMC/strategy/optimization.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index dc11152d..fc4d307d 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -98,16 +98,13 @@ def optimize( final_log_prob: Float[Array, " n_chain"] Final log-probabilities of the optimized positions. """ - data_dict = data # prevent shadowing inside the scan - grad_fn = jax.jit( - jax.grad(objective, argnums=0) - ) # differentiate w.r.t. params only + grad_fn = jax.jit(jax.grad(objective)) def _kernel(carry, _step): key, params, opt_state = carry key, subkey = jax.random.split(key) - grad = grad_fn(params, data_dict) * ( + grad = grad_fn(params, data) * ( 1 + jax.random.normal(subkey) * self.noise_level ) updates, opt_state = self.solver.update(grad, opt_state, params) From d768b0144df6574ff92722416fba86a82849c747 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 11:05:43 +0800 Subject: [PATCH 13/14] Enhance bounds validation in AdamOptimization class and update docstring for clarity --- src/flowMC/strategy/optimization.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index fc4d307d..63ff4bf8 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -22,6 +22,9 @@ class AdamOptimization(Strategy): Variance of the noise added to the gradients. bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) Bounds for the optimization. The optimization will be projected to these bounds. + If bounds has shape (1, 2), it will be broadcast to all dimensions. For n_dim > 1, + passing a (1, 2) array applies the same bound to every dimension. To specify different + bounds per dimension, provide an array of shape (n_dim, 2). """ logpdf: Callable[[Float[Array, " n_dim"], dict], Float] @@ -47,6 +50,12 @@ def __init__( self.noise_level = noise_level self.bounds = bounds + # Validate bounds shape + if bounds.ndim != 2 or bounds.shape[1] != 2: + raise ValueError(f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}") + # If bounds is (1, 2), it will be broadcast to all dimensions. If not, check compatibility. + # Try to infer n_dim from logpdf signature or initial_position, but here we can't, so warn in runtime. + self.solver = optax.chain( optax.adam(learning_rate=self.learning_rate), ) @@ -60,7 +69,7 @@ def __call__( ) -> tuple[ PRNGKeyArray, dict[str, Resource], - Float[Array, " n_chains n_dim"], + Float[Array, " n_chain n_dim"], ]: def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float: return -self.logpdf(params, data) @@ -78,6 +87,14 @@ def optimize( initial_position: Float[Array, " n_chain n_dim"], data: dict, ): + # Validate bounds shape against n_dim + n_dim = initial_position.shape[-1] + if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim): + raise ValueError( + f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. " + "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds." + ) + """Optimization kernel. This can be used independently of the __call__ method. Args: From 15e103ee09c92ab61c5d063b63bad5e3a35940bb Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Wed, 14 May 2025 11:10:11 +0800 Subject: [PATCH 14/14] Formatting --- src/flowMC/strategy/optimization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 63ff4bf8..9f9719f4 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -52,7 +52,9 @@ def __init__( # Validate bounds shape if bounds.ndim != 2 or bounds.shape[1] != 2: - raise ValueError(f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}") + raise ValueError( + f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}" + ) # If bounds is (1, 2), it will be broadcast to all dimensions. If not, check compatibility. # Try to infer n_dim from logpdf signature or initial_position, but here we can't, so warn in runtime.