Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/flowMC/resource/nf_model/NF_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,21 @@ def scan_sample(
def body(carry, data):
(
rng_key,
position_current,
log_prob_current,
log_prob_nf_current,
position_initial,
log_prob_initial,
log_prob_nf_initial,
) = carry
(position_proposed, log_prob_proposal, log_prob_nf_proposal) = data

(position_proposal, log_prob_proposal, log_prob_nf_proposal) = data
rng_key, subkey = random.split(rng_key)
ratio = (log_prob_proposal - log_prob_current) - (
log_prob_nf_proposal - log_prob_nf_current
ratio = (log_prob_proposal - log_prob_initial) - (
log_prob_nf_proposal - log_prob_nf_initial
)
uniform_random = jnp.log(jax.random.uniform(subkey))
do_accept = uniform_random < ratio
position_current = jnp.where(do_accept, position_proposed, position_current)
log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_current)
position_current = jnp.where(do_accept, position_proposal, position_initial)
log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_initial)
log_prob_nf_current = jnp.where(
do_accept, log_prob_nf_proposal, log_prob_nf_current
do_accept, log_prob_nf_proposal, log_prob_nf_initial
)

return (rng_key, position_current, log_prob_current, log_prob_nf_current), (
Expand Down
3 changes: 3 additions & 0 deletions src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
n_production_loops: int,
n_epochs: int,
mala_step_size: float = 1e-1,
chain_batch_size: int = 0,
rq_spline_hidden_units: list[int] = [32, 32],
rq_spline_n_bins: int = 8,
rq_spline_n_layers: int = 4,
Expand Down Expand Up @@ -135,6 +136,7 @@ def __init__(
["target_positions", "target_log_prob", "target_local_accs"],
n_local_steps,
thinning=local_thinning,
chain_batch_size=chain_batch_size,
verbose=verbose,
)

Expand All @@ -145,6 +147,7 @@ def __init__(
["target_positions", "target_log_prob", "target_global_accs"],
n_global_steps,
thinning=global_thinning,
chain_batch_size=chain_batch_size,
verbose=verbose,
)

Expand Down
3 changes: 3 additions & 0 deletions src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
n_production_loops: int,
n_epochs: int,
mala_step_size: float = 1e-1,
chain_batch_size: int = 0,
rq_spline_hidden_units: list[int] = [32, 32],
rq_spline_n_bins: int = 8,
rq_spline_n_layers: int = 4,
Expand Down Expand Up @@ -165,6 +166,7 @@ def __init__(
["target_positions", "target_log_prob", "target_local_accs"],
n_local_steps,
thinning=local_thinning,
chain_batch_size=chain_batch_size,
verbose=verbose,
)

Expand All @@ -175,6 +177,7 @@ def __init__(
["target_positions", "target_log_prob", "target_global_accs"],
n_global_steps,
thinning=global_thinning,
chain_batch_size=chain_batch_size,
verbose=verbose,
)

Expand Down
37 changes: 32 additions & 5 deletions src/flowMC/strategy/take_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flowMC.strategy.base import Strategy
from jaxtyping import Array, Float, PRNGKeyArray
import jax
import jax.numpy as jnp
import equinox as eqx
from abc import abstractmethod

Expand All @@ -18,6 +19,7 @@ class TakeSteps(Strategy):
n_steps: int
current_position: int
thinning: int
chain_batch_size: int # If vmap over a large number of chains is memory bounded, this splits the computation
verbose: bool

def __init__(
Expand All @@ -28,6 +30,7 @@ def __init__(
buffer_names: list[str],
n_steps: int,
thinning: int = 1,
chain_batch_size: int = 0,
verbose: bool = False,
):
self.logpdf_name = logpdf_name
Expand All @@ -37,6 +40,7 @@ def __init__(
self.n_steps = n_steps
self.current_position = 0
self.thinning = thinning
self.chain_batch_size = chain_batch_size
self.verbose = verbose

@abstractmethod
Expand Down Expand Up @@ -98,11 +102,34 @@ def __call__(

# Filter jit will bypass the compilation of
# the function if not clearing the cache
positions, log_probs, do_accepts = eqx.filter_jit(
eqx.filter_vmap(
jax.tree_util.Partial(self.sample, kernel), in_axes=(0, 0, None, None)
)
)(subkey, initial_position, logpdf, data)
n_chains = initial_position.shape[0]
if self.chain_batch_size > 1 and n_chains > self.chain_batch_size:
positions_list = []
log_probs_list = []
do_accepts_list = []
for i in range(0, n_chains, self.chain_batch_size):
batch_slice = slice(i, min(i + self.chain_batch_size, n_chains))
subkey_batch = subkey[batch_slice]
initial_position_batch = initial_position[batch_slice]
positions_batch, log_probs_batch, do_accepts_batch = eqx.filter_jit(
eqx.filter_vmap(
jax.tree_util.Partial(self.sample, kernel),
in_axes=(0, 0, None, None),
)
)(subkey_batch, initial_position_batch, logpdf, data)
positions_list.append(positions_batch)
log_probs_list.append(log_probs_batch)
do_accepts_list.append(do_accepts_batch)
positions = jnp.concatenate(positions_list, axis=0)
log_probs = jnp.concatenate(log_probs_list, axis=0)
do_accepts = jnp.concatenate(do_accepts_list, axis=0)
else:
positions, log_probs, do_accepts = eqx.filter_jit(
eqx.filter_vmap(
jax.tree_util.Partial(self.sample, kernel),
in_axes=(0, 0, None, None),
)
)(subkey, initial_position, logpdf, data)

positions = positions[:, :: self.thinning]
log_probs = log_probs[:, :: self.thinning]
Expand Down
75 changes: 49 additions & 26 deletions test/unit/test_strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
import pytest

from flowMC.resource.nf_model.rqSpline import MaskedCouplingRQSpline
from flowMC.resource.optimizer import Optimizer
Expand Down Expand Up @@ -78,22 +79,19 @@ def loss_fn(params: Float[Array, " n_dim"], data: dict = {}) -> Float:


class TestLocalStep:
def test_take_local_step(self):
@pytest.fixture(autouse=True)
def setup(self):
n_chains = 5
n_steps = 25
n_dims = 2
n_batch = 5

test_position = Buffer("test_position", (n_chains, n_steps, n_dims), 1)
test_log_prob = Buffer("test_log_prob", (n_chains, n_steps), 1)
test_acceptance = Buffer("test_acceptance", (n_chains, n_steps), 1)

mala_kernel = MALA(1.0)
grw_kernel = GaussianRandomWalk(1.0)
hmc_kernel = HMC(jnp.eye(n_dims), 0.1, 10)

logpdf = LogPDF(log_posterior, n_dims=n_dims)

sampler_state = State(
{
"test_position": "test_position",
Expand All @@ -102,8 +100,10 @@ def test_take_local_step(self):
},
name="sampler_state",
)

resources = {
self.n_batch = n_batch
self.n_dims = n_dims
self.test_position = test_position
self.resources = {
"test_position": test_position,
"test_log_prob": test_log_prob,
"test_acceptance": test_acceptance,
Expand All @@ -114,51 +114,74 @@ def test_take_local_step(self):
"sampler_state": sampler_state,
}

def test_take_local_step(self):
strategy = TakeSerialSteps(
"logpdf",
"MALA",
"sampler_state",
["test_position", "test_log_prob", "test_acceptance"],
n_batch,
self.n_batch,
)
key = jax.random.PRNGKey(42)
positions = test_position.data[:, 0]

for i in range(n_batch):
positions = self.test_position.data[:, 0]
for _ in range(self.n_batch):
key, subkey1, subkey2 = jax.random.split(key, 3)
_, resources, positions = strategy(
_, self.resources, positions = strategy(
rng_key=subkey1,
resources=resources,
resources=self.resources,
initial_position=positions,
data={"data": jnp.arange(n_dims)},
data={"data": jnp.arange(self.n_dims)},
)

key, subkey1, subkey2 = jax.random.split(key, 3)
strategy.set_current_position(0)
_, resources, positions = strategy(
_, self.resources, positions = strategy(
rng_key=subkey1,
resources=resources,
resources=self.resources,
initial_position=positions,
data={"data": jnp.arange(n_dims)},
data={"data": jnp.arange(self.n_dims)},
)

key, subkey1, subkey2 = jax.random.split(key, 3)
strategy.kernel_name = "GRW"
strategy.set_current_position(0)
_, resources, positions = strategy(
_, self.resources, positions = strategy(
rng_key=subkey1,
resources=resources,
resources=self.resources,
initial_position=positions,
data={"data": jnp.arange(n_dims)},
data={"data": jnp.arange(self.n_dims)},
)

strategy.kernel_name = "HMC"
_, resources, positions = strategy(
_, self.resources, positions = strategy(
rng_key=subkey1,
resources=resources,
resources=self.resources,
initial_position=positions,
data={"data": jnp.arange(n_dims)},
data={"data": jnp.arange(self.n_dims)},
)

def test_take_local_step_chain_batch_size(self):
# Use a chain_batch_size smaller than the number of chains to trigger batching logic
chain_batch_size = 2
strategy = TakeSerialSteps(
"logpdf",
"MALA",
"sampler_state",
["test_position", "test_log_prob", "test_acceptance"],
self.n_batch,
chain_batch_size=chain_batch_size,
)
key = jax.random.PRNGKey(42)
positions = self.test_position.data[:, 0]
# Run the strategy, which should use batching internally
_, _, final_positions = strategy(
rng_key=key,
resources=self.resources,
initial_position=positions,
data={"data": jnp.arange(self.n_dims)},
)
# Check that the output shape is correct
assert final_positions.shape == (positions.shape[0], positions.shape[1])
# Optionally, check that the buffer was updated for all chains
assert isinstance(test_position := self.resources["test_position"], Buffer)
assert test_position.data.shape[0] == positions.shape[0]


class TestNFStrategies:
Expand Down
Loading