From 61087d862169569a5ca57977a841bd6bfd9dbc5b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 18 Aug 2025 11:28:38 -0400 Subject: [PATCH 1/3] add tests --- .../resource/model/flowmatching/base.py | 6 +- test/unit/test_flowmatching.py | 214 ++++++++++++++++++ 2 files changed, 217 insertions(+), 3 deletions(-) create mode 100644 test/unit/test_flowmatching.py diff --git a/src/flowMC/resource/model/flowmatching/base.py b/src/flowMC/resource/model/flowmatching/base.py index 00620b37..08935cba 100644 --- a/src/flowMC/resource/model/flowmatching/base.py +++ b/src/flowMC/resource/model/flowmatching/base.py @@ -17,9 +17,9 @@ class Solver(eqx.Module): model: MLP # Shape should be [input_dim + t_dim, hiddens, output_dim] method: AbstractSolver - def __init__(self, model: MLP, method: str = "dopri5"): + def __init__(self, model: MLP, method: AbstractSolver = Dopri5()): self.model = model - self.method = Dopri5() + self.method = method def sample( self, rng_key: PRNGKeyArray, n_samples: int, dt: Float = 1e-1 @@ -62,7 +62,7 @@ def log_prob(self, x1: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float: def model_wrapper( t: Float, x: Float[Array, " n_dims"], args: PyTree - ) -> list[Float[Array, " ,,,"]]: + ) -> list[Float[Array, " ..."]]: """Wrapper for the model to be used in the ODE solver. The output shape should be [n_dims, 1]. diff --git a/test/unit/test_flowmatching.py b/test/unit/test_flowmatching.py new file mode 100644 index 00000000..5de44b89 --- /dev/null +++ b/test/unit/test_flowmatching.py @@ -0,0 +1,214 @@ +import jax +import jax.numpy as jnp +import pytest + +from flowMC.resource.model.flowmatching.base import ( + FlowMatchingModel, + Solver, + Path, + CondOTScheduler, +) +from flowMC.resource.model.common import MLP +from diffrax import Dopri5 +import equinox as eqx + +def get_simple_mlp(n_input, n_hidden, n_output, key): + # Simple 2-layer MLP for testing + # shape is a list: [input_dim, hidden_dim(s), output_dim] + shape = [n_input] + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden)) + [n_output] + return MLP( + shape=shape, + key=key, + activation=jax.nn.swish + ) + +def test_flowmatchingmodel_sample_and_log_prob(): + key = jax.random.PRNGKey(42) + n_dim = 2 + n_hidden = 8 + + # Setup model components + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + scheduler = CondOTScheduler() + path = Path(scheduler=scheduler) + + # Create FlowMatchingModel + model = FlowMatchingModel( + solver=solver, + path=path, + data_mean=jnp.zeros(n_dim), + data_cov=jnp.eye(n_dim) + ) + + # Test sampling + rng, subkey = jax.random.split(key) + n_samples = 4 + samples = model.sample(subkey, n_samples) + assert samples.shape == (n_samples, n_dim) + assert jnp.isfinite(samples).all() + + # Test log_prob + logp = eqx.filter_vmap(model.log_prob)(samples) + assert logp.shape == (n_samples,1) + assert jnp.isfinite(logp).all() + +def test_flowmatchingmodel_train_step_and_epoch(): + import optax + + key = jax.random.PRNGKey(123) + n_dim = 2 + n_hidden = 8 + n_batch = 5 + + # Setup model components + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + scheduler = CondOTScheduler() + path = Path(scheduler=scheduler) + + # Create FlowMatchingModel + model = FlowMatchingModel( + solver=solver, + path=path, + data_mean=jnp.zeros(n_dim), + data_cov=jnp.eye(n_dim) + ) + + # Dummy data for training + x0 = jax.random.normal(key, (n_batch, n_dim)) + x1 = jax.random.normal(key, (n_batch, n_dim)) + t = jax.random.uniform(key, (n_batch, 1)) + + # Prepare optimizer + optim = optax.adam(learning_rate=1e-3) + state = optim.init(eqx.filter(model, eqx.is_array)) + + # Test train_step + std = jnp.sqrt(jnp.diag(model.data_cov)) + x1_whitened = (x1 - model.data_mean) / std + x_t, dx_t = model.path.sample(x0, x1_whitened, t) + loss, model2, state2 = model.train_step(x_t, t, dx_t, optim, state) + assert jnp.isfinite(loss) + assert isinstance(model2, FlowMatchingModel) + + # Test train_epoch + data = (x0, x1, t) + loss_epoch, model3, state3 = model.train_epoch(key, optim, state, data, batch_size=n_batch) + assert jnp.isfinite(loss_epoch) + assert isinstance(model3, FlowMatchingModel) + +def test_solver_sample_and_log_prob_shapes_and_finiteness(): + key = jax.random.PRNGKey(0) + n_dim = 3 + n_hidden = 4 + n_samples = 7 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + # Test sample + samples = solver.sample(key, n_samples) + assert samples.shape == (n_samples, n_dim) + assert jnp.isfinite(samples).all() + # Test log_prob + x1 = jax.random.normal(key, (n_dim,)) + logp = solver.log_prob(x1) + logp_arr = jnp.asarray(logp) + assert logp_arr.size == 1 + assert jnp.isfinite(logp_arr).all() + +def test_solver_sample_various_dt(): + key = jax.random.PRNGKey(1) + n_dim = 2 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + for dt in [1e-2, 1e-1, 0.5]: + samples = solver.sample(key, 3, dt=dt) + assert samples.shape == (3, n_dim) + assert jnp.isfinite(samples).all() + +def test_path_sample_shapes_and_values(): + n_dim = 2 + scheduler = CondOTScheduler() + path = Path(scheduler=scheduler) + x0 = jnp.ones((5, n_dim)) + x1 = jnp.zeros((5, n_dim)) + for t_val in [0.0, 0.5, 1.0]: + t = jnp.full((5, 1), t_val) + x_t, dx_t = path.sample(x0, x1, t) + assert x_t.shape == (5, n_dim) + assert dx_t.shape == (5, n_dim) + +def test_condotscheduler_call_output(): + sched = CondOTScheduler() + for t in [0.0, 1.0, 0.5, -0.1, 1.1]: + out = sched(jnp.array(t)) + assert isinstance(out, tuple) + assert len(out) == 4 + assert all(isinstance(float(x), float) for x in out) + +def test_flowmatchingmodel_sample_and_log_prob_various_shapes(): + key = jax.random.PRNGKey(2) + n_dim = 2 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + path = Path(scheduler=CondOTScheduler()) + model = FlowMatchingModel(solver=solver, path=path) + for n_samples in [1, 5, 10]: + samples = model.sample(key, n_samples) + assert samples.shape == (n_samples, n_dim) + assert jnp.isfinite(samples).all() + logp = eqx.filter_vmap(model.log_prob)(samples) + assert logp.shape[0] == n_samples + assert jnp.isfinite(logp).all() + +def test_flowmatchingmodel_log_prob_edge_cases(): + key = jax.random.PRNGKey(3) + n_dim = 2 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + path = Path(scheduler=CondOTScheduler()) + model = FlowMatchingModel(solver=solver, path=path) + # Edge cases: zeros, large values + for arr in [jnp.zeros(n_dim), 1e6 * jnp.ones(n_dim), -1e6 * jnp.ones(n_dim)]: + logp = model.log_prob(arr) + logp_arr = jnp.asarray(logp) + assert logp_arr.size == 1 + assert jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() # may be nan for extreme values + +def test_flowmatchingmodel_save_and_load(tmp_path): + key = jax.random.PRNGKey(4) + n_dim = 2 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + path_obj = Path(scheduler=CondOTScheduler()) + model = FlowMatchingModel(solver=solver, path=path_obj) + # Save and load + save_path = str(tmp_path / "test_model") + model.save_model(save_path) + loaded = model.load_model(save_path) + # Check that loaded model produces same output for same input + x = jax.random.normal(key, (2, n_dim)) + assert jnp.allclose(eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x)) + +def test_flowmatchingmodel_properties(): + key = jax.random.PRNGKey(5) + n_dim = 3 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + path = Path(scheduler=CondOTScheduler()) + mean = jnp.arange(n_dim) + cov = jnp.eye(n_dim) * 2 + model = FlowMatchingModel(solver=solver, path=path, data_mean=mean, data_cov=cov) + assert model.n_features == n_dim + assert jnp.allclose(model.data_mean, mean) + assert jnp.allclose(model.data_cov, cov) + +def test_flowmatchingmodel_print_parameters_notimplemented(): + key = jax.random.PRNGKey(6) + n_dim = 2 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + path = Path(scheduler=CondOTScheduler()) + model = FlowMatchingModel(solver=solver, path=path) + with pytest.raises(NotImplementedError): + model.print_parameters() From 8a2ced08ce10e6aee85c3bf6c0d7b8c6173043e4 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 18 Aug 2025 11:31:29 -0400 Subject: [PATCH 2/3] Update test_flowmatching.py --- test/unit/test_flowmatching.py | 309 ++++++++++++++------------------- 1 file changed, 135 insertions(+), 174 deletions(-) diff --git a/test/unit/test_flowmatching.py b/test/unit/test_flowmatching.py index 5de44b89..02a1ae3c 100644 --- a/test/unit/test_flowmatching.py +++ b/test/unit/test_flowmatching.py @@ -11,10 +11,10 @@ from flowMC.resource.model.common import MLP from diffrax import Dopri5 import equinox as eqx +import optax def get_simple_mlp(n_input, n_hidden, n_output, key): - # Simple 2-layer MLP for testing - # shape is a list: [input_dim, hidden_dim(s), output_dim] + """Simple 2-layer MLP for testing.""" shape = [n_input] + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden)) + [n_output] return MLP( shape=shape, @@ -22,138 +22,101 @@ def get_simple_mlp(n_input, n_hidden, n_output, key): activation=jax.nn.swish ) -def test_flowmatchingmodel_sample_and_log_prob(): - key = jax.random.PRNGKey(42) - n_dim = 2 - n_hidden = 8 - - # Setup model components - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - scheduler = CondOTScheduler() - path = Path(scheduler=scheduler) - - # Create FlowMatchingModel - model = FlowMatchingModel( - solver=solver, - path=path, - data_mean=jnp.zeros(n_dim), - data_cov=jnp.eye(n_dim) - ) +############################## +# Solver Tests +############################## + +class TestSolver: + @pytest.fixture + def solver(self): + key = jax.random.PRNGKey(0) + n_dim = 3 + n_hidden = 4 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + return Solver(model=mlp, method=Dopri5()), key, n_dim + + def test_sample_shape_and_finiteness(self, solver): + solver, key, n_dim = solver + n_samples = 7 + samples = solver.sample(key, n_samples) + assert samples.shape == (n_samples, n_dim) + assert jnp.isfinite(samples).all() - # Test sampling - rng, subkey = jax.random.split(key) - n_samples = 4 - samples = model.sample(subkey, n_samples) - assert samples.shape == (n_samples, n_dim) - assert jnp.isfinite(samples).all() - - # Test log_prob - logp = eqx.filter_vmap(model.log_prob)(samples) - assert logp.shape == (n_samples,1) - assert jnp.isfinite(logp).all() - -def test_flowmatchingmodel_train_step_and_epoch(): - import optax - - key = jax.random.PRNGKey(123) - n_dim = 2 - n_hidden = 8 - n_batch = 5 - - # Setup model components - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - scheduler = CondOTScheduler() - path = Path(scheduler=scheduler) - - # Create FlowMatchingModel - model = FlowMatchingModel( - solver=solver, - path=path, - data_mean=jnp.zeros(n_dim), - data_cov=jnp.eye(n_dim) - ) + def test_log_prob_shape_and_finiteness(self, solver): + solver, key, n_dim = solver + x1 = jax.random.normal(key, (n_dim,)) + logp = solver.log_prob(x1) + logp_arr = jnp.asarray(logp) + assert logp_arr.size == 1 + assert jnp.isfinite(logp_arr).all() - # Dummy data for training - x0 = jax.random.normal(key, (n_batch, n_dim)) - x1 = jax.random.normal(key, (n_batch, n_dim)) - t = jax.random.uniform(key, (n_batch, 1)) - - # Prepare optimizer - optim = optax.adam(learning_rate=1e-3) - state = optim.init(eqx.filter(model, eqx.is_array)) - - # Test train_step - std = jnp.sqrt(jnp.diag(model.data_cov)) - x1_whitened = (x1 - model.data_mean) / std - x_t, dx_t = model.path.sample(x0, x1_whitened, t) - loss, model2, state2 = model.train_step(x_t, t, dx_t, optim, state) - assert jnp.isfinite(loss) - assert isinstance(model2, FlowMatchingModel) - - # Test train_epoch - data = (x0, x1, t) - loss_epoch, model3, state3 = model.train_epoch(key, optim, state, data, batch_size=n_batch) - assert jnp.isfinite(loss_epoch) - assert isinstance(model3, FlowMatchingModel) - -def test_solver_sample_and_log_prob_shapes_and_finiteness(): - key = jax.random.PRNGKey(0) - n_dim = 3 - n_hidden = 4 - n_samples = 7 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - # Test sample - samples = solver.sample(key, n_samples) - assert samples.shape == (n_samples, n_dim) - assert jnp.isfinite(samples).all() - # Test log_prob - x1 = jax.random.normal(key, (n_dim,)) - logp = solver.log_prob(x1) - logp_arr = jnp.asarray(logp) - assert logp_arr.size == 1 - assert jnp.isfinite(logp_arr).all() - -def test_solver_sample_various_dt(): - key = jax.random.PRNGKey(1) - n_dim = 2 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - for dt in [1e-2, 1e-1, 0.5]: + @pytest.mark.parametrize("dt", [1e-2, 1e-1, 0.5]) + def test_sample_various_dt(self, solver, dt): + solver, key, n_dim = solver samples = solver.sample(key, 3, dt=dt) assert samples.shape == (3, n_dim) assert jnp.isfinite(samples).all() -def test_path_sample_shapes_and_values(): - n_dim = 2 - scheduler = CondOTScheduler() - path = Path(scheduler=scheduler) - x0 = jnp.ones((5, n_dim)) - x1 = jnp.zeros((5, n_dim)) - for t_val in [0.0, 0.5, 1.0]: - t = jnp.full((5, 1), t_val) - x_t, dx_t = path.sample(x0, x1, t) - assert x_t.shape == (5, n_dim) - assert dx_t.shape == (5, n_dim) - -def test_condotscheduler_call_output(): - sched = CondOTScheduler() - for t in [0.0, 1.0, 0.5, -0.1, 1.1]: +############################## +# Path & Scheduler Tests +############################## + +class TestPathAndScheduler: + def test_path_sample_shapes_and_values(self): + n_dim = 2 + scheduler = CondOTScheduler() + path = Path(scheduler=scheduler) + x0 = jnp.ones((5, n_dim)) + x1 = jnp.zeros((5, n_dim)) + for t_val in [0.0, 0.5, 1.0]: + t = jnp.full((5, 1), t_val) + x_t, dx_t = path.sample(x0, x1, t) + assert x_t.shape == (5, n_dim) + assert dx_t.shape == (5, n_dim) + + @pytest.mark.parametrize("t", [0.0, 1.0, 0.5, -0.1, 1.1]) + def test_condotscheduler_call_output(self, t): + sched = CondOTScheduler() out = sched(jnp.array(t)) assert isinstance(out, tuple) assert len(out) == 4 assert all(isinstance(float(x), float) for x in out) -def test_flowmatchingmodel_sample_and_log_prob_various_shapes(): - key = jax.random.PRNGKey(2) - n_dim = 2 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - path = Path(scheduler=CondOTScheduler()) - model = FlowMatchingModel(solver=solver, path=path) - for n_samples in [1, 5, 10]: +############################## +# FlowMatchingModel Tests +############################## + +class TestFlowMatchingModel: + @pytest.fixture + def model(self): + key = jax.random.PRNGKey(42) + n_dim = 2 + n_hidden = 8 + mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + solver = Solver(model=mlp, method=Dopri5()) + scheduler = CondOTScheduler() + path = Path(scheduler=scheduler) + model = FlowMatchingModel( + solver=solver, + path=path, + data_mean=jnp.zeros(n_dim), + data_cov=jnp.eye(n_dim) + ) + return model, key, n_dim + + def test_sample_and_log_prob(self, model): + model, key, n_dim = model + n_samples = 4 + samples = model.sample(key, n_samples) + assert samples.shape == (n_samples, n_dim) + assert jnp.isfinite(samples).all() + logp = eqx.filter_vmap(model.log_prob)(samples) + assert logp.shape == (n_samples, 1) + assert jnp.isfinite(logp).all() + + @pytest.mark.parametrize("n_samples", [1, 5, 10]) + def test_sample_various_shapes(self, model, n_samples): + model, key, n_dim = model samples = model.sample(key, n_samples) assert samples.shape == (n_samples, n_dim) assert jnp.isfinite(samples).all() @@ -161,54 +124,52 @@ def test_flowmatchingmodel_sample_and_log_prob_various_shapes(): assert logp.shape[0] == n_samples assert jnp.isfinite(logp).all() -def test_flowmatchingmodel_log_prob_edge_cases(): - key = jax.random.PRNGKey(3) - n_dim = 2 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - path = Path(scheduler=CondOTScheduler()) - model = FlowMatchingModel(solver=solver, path=path) - # Edge cases: zeros, large values - for arr in [jnp.zeros(n_dim), 1e6 * jnp.ones(n_dim), -1e6 * jnp.ones(n_dim)]: - logp = model.log_prob(arr) - logp_arr = jnp.asarray(logp) - assert logp_arr.size == 1 - assert jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() # may be nan for extreme values - -def test_flowmatchingmodel_save_and_load(tmp_path): - key = jax.random.PRNGKey(4) - n_dim = 2 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - path_obj = Path(scheduler=CondOTScheduler()) - model = FlowMatchingModel(solver=solver, path=path_obj) - # Save and load - save_path = str(tmp_path / "test_model") - model.save_model(save_path) - loaded = model.load_model(save_path) - # Check that loaded model produces same output for same input - x = jax.random.normal(key, (2, n_dim)) - assert jnp.allclose(eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x)) - -def test_flowmatchingmodel_properties(): - key = jax.random.PRNGKey(5) - n_dim = 3 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - path = Path(scheduler=CondOTScheduler()) - mean = jnp.arange(n_dim) - cov = jnp.eye(n_dim) * 2 - model = FlowMatchingModel(solver=solver, path=path, data_mean=mean, data_cov=cov) - assert model.n_features == n_dim - assert jnp.allclose(model.data_mean, mean) - assert jnp.allclose(model.data_cov, cov) - -def test_flowmatchingmodel_print_parameters_notimplemented(): - key = jax.random.PRNGKey(6) - n_dim = 2 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=4, n_output=n_dim, key=key) - solver = Solver(model=mlp, method=Dopri5()) - path = Path(scheduler=CondOTScheduler()) - model = FlowMatchingModel(solver=solver, path=path) - with pytest.raises(NotImplementedError): - model.print_parameters() + def test_log_prob_edge_cases(self, model): + model, key, n_dim = model + for arr in [jnp.zeros(n_dim), 1e6 * jnp.ones(n_dim), -1e6 * jnp.ones(n_dim)]: + logp = model.log_prob(arr) + logp_arr = jnp.asarray(logp) + assert logp_arr.size == 1 + assert jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() # may be nan for extreme values + + def test_save_and_load(self, tmp_path, model): + model, key, n_dim = model + save_path = str(tmp_path / "test_model") + model.save_model(save_path) + loaded = model.load_model(save_path) + x = jax.random.normal(key, (2, n_dim)) + assert jnp.allclose(eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x)) + + def test_properties(self, model): + model, key, n_dim = model + mean = jnp.arange(n_dim) + cov = jnp.eye(n_dim) * 2 + model2 = FlowMatchingModel(solver=model.solver, path=model.path, data_mean=mean, data_cov=cov) + assert model2.n_features == n_dim + assert jnp.allclose(model2.data_mean, mean) + assert jnp.allclose(model2.data_cov, cov) + + def test_print_parameters_notimplemented(self, model): + model, key, n_dim = model + with pytest.raises(NotImplementedError): + model.print_parameters() + + def test_train_step_and_epoch(self, model): + model, key, n_dim = model + n_batch = 5 + n_hidden = 8 + x0 = jax.random.normal(key, (n_batch, n_dim)) + x1 = jax.random.normal(key, (n_batch, n_dim)) + t = jax.random.uniform(key, (n_batch, 1)) + optim = optax.adam(learning_rate=1e-3) + state = optim.init(eqx.filter(model, eqx.is_array)) + std = jnp.sqrt(jnp.diag(model.data_cov)) + x1_whitened = (x1 - model.data_mean) / std + x_t, dx_t = model.path.sample(x0, x1_whitened, t) + loss, model2, state2 = model.train_step(x_t, t, dx_t, optim, state) + assert jnp.isfinite(loss) + assert isinstance(model2, FlowMatchingModel) + data = (x0, x1, t) + loss_epoch, model3, state3 = model.train_epoch(key, optim, state, data, batch_size=n_batch) + assert jnp.isfinite(loss_epoch) + assert isinstance(model3, FlowMatchingModel) From 1729bdfa72d26770c8bd1dd583c1756227edce9a Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 18 Aug 2025 11:32:10 -0400 Subject: [PATCH 3/3] Format test_flowmatching.py for improved readability --- test/unit/test_flowmatching.py | 44 ++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/test/unit/test_flowmatching.py b/test/unit/test_flowmatching.py index 02a1ae3c..4fa95737 100644 --- a/test/unit/test_flowmatching.py +++ b/test/unit/test_flowmatching.py @@ -13,26 +13,31 @@ import equinox as eqx import optax + def get_simple_mlp(n_input, n_hidden, n_output, key): """Simple 2-layer MLP for testing.""" - shape = [n_input] + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden)) + [n_output] - return MLP( - shape=shape, - key=key, - activation=jax.nn.swish + shape = ( + [n_input] + + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden)) + + [n_output] ) + return MLP(shape=shape, key=key, activation=jax.nn.swish) + ############################## # Solver Tests ############################## + class TestSolver: @pytest.fixture def solver(self): key = jax.random.PRNGKey(0) n_dim = 3 n_hidden = 4 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + mlp = get_simple_mlp( + n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key + ) return Solver(model=mlp, method=Dopri5()), key, n_dim def test_sample_shape_and_finiteness(self, solver): @@ -57,10 +62,12 @@ def test_sample_various_dt(self, solver, dt): assert samples.shape == (3, n_dim) assert jnp.isfinite(samples).all() + ############################## # Path & Scheduler Tests ############################## + class TestPathAndScheduler: def test_path_sample_shapes_and_values(self): n_dim = 2 @@ -82,17 +89,21 @@ def test_condotscheduler_call_output(self, t): assert len(out) == 4 assert all(isinstance(float(x), float) for x in out) + ############################## # FlowMatchingModel Tests ############################## + class TestFlowMatchingModel: @pytest.fixture def model(self): key = jax.random.PRNGKey(42) n_dim = 2 n_hidden = 8 - mlp = get_simple_mlp(n_input=n_dim+1, n_hidden=n_hidden, n_output=n_dim, key=key) + mlp = get_simple_mlp( + n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key + ) solver = Solver(model=mlp, method=Dopri5()) scheduler = CondOTScheduler() path = Path(scheduler=scheduler) @@ -100,7 +111,7 @@ def model(self): solver=solver, path=path, data_mean=jnp.zeros(n_dim), - data_cov=jnp.eye(n_dim) + data_cov=jnp.eye(n_dim), ) return model, key, n_dim @@ -130,7 +141,9 @@ def test_log_prob_edge_cases(self, model): logp = model.log_prob(arr) logp_arr = jnp.asarray(logp) assert logp_arr.size == 1 - assert jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() # may be nan for extreme values + assert ( + jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() + ) # may be nan for extreme values def test_save_and_load(self, tmp_path, model): model, key, n_dim = model @@ -138,13 +151,17 @@ def test_save_and_load(self, tmp_path, model): model.save_model(save_path) loaded = model.load_model(save_path) x = jax.random.normal(key, (2, n_dim)) - assert jnp.allclose(eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x)) + assert jnp.allclose( + eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x) + ) def test_properties(self, model): model, key, n_dim = model mean = jnp.arange(n_dim) cov = jnp.eye(n_dim) * 2 - model2 = FlowMatchingModel(solver=model.solver, path=model.path, data_mean=mean, data_cov=cov) + model2 = FlowMatchingModel( + solver=model.solver, path=model.path, data_mean=mean, data_cov=cov + ) assert model2.n_features == n_dim assert jnp.allclose(model2.data_mean, mean) assert jnp.allclose(model2.data_cov, cov) @@ -157,7 +174,6 @@ def test_print_parameters_notimplemented(self, model): def test_train_step_and_epoch(self, model): model, key, n_dim = model n_batch = 5 - n_hidden = 8 x0 = jax.random.normal(key, (n_batch, n_dim)) x1 = jax.random.normal(key, (n_batch, n_dim)) t = jax.random.uniform(key, (n_batch, 1)) @@ -170,6 +186,8 @@ def test_train_step_and_epoch(self, model): assert jnp.isfinite(loss) assert isinstance(model2, FlowMatchingModel) data = (x0, x1, t) - loss_epoch, model3, state3 = model.train_epoch(key, optim, state, data, batch_size=n_batch) + loss_epoch, model3, state3 = model.train_epoch( + key, optim, state, data, batch_size=n_batch + ) assert jnp.isfinite(loss_epoch) assert isinstance(model3, FlowMatchingModel)