From 826b2bd8e3fa0e921825311b710e34db668f5ada Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 17 Aug 2025 12:18:13 +0000 Subject: [PATCH 1/2] Add whitening procedure to FlowMatchingModel --- .../resource/model/flowmatching/base.py | 51 ++++++++++++++++--- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/src/flowMC/resource/model/flowmatching/base.py b/src/flowMC/resource/model/flowmatching/base.py index 896f77a0..fd7607e9 100644 --- a/src/flowMC/resource/model/flowmatching/base.py +++ b/src/flowMC/resource/model/flowmatching/base.py @@ -5,6 +5,7 @@ from flowMC.resource.base import Resource from flowMC.resource.model.common import MLP from typing_extensions import Self +from typing import Optional import jax.numpy as jnp import jax from jax.scipy.stats.multivariate_normal import logpdf @@ -110,18 +111,53 @@ class FlowMatchingModel(eqx.Module, Resource): solver: Solver path: Path + _data_mean: Float[Array, " n_dim"] + _data_cov: Float[Array, " n_dim n_dim"] - def __init__(self, solver: Solver, path: Path): + @property + def n_features(self): + return self.solver.model.n_input - 1 + + @property + def data_mean(self): + return jax.lax.stop_gradient(self._data_mean) + + @property + def data_cov(self): + return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov)) + + def __init__( + self, + solver: Solver, + path: Path, + data_mean: Optional[Float[Array, " n_dim"]] = None, + data_cov: Optional[Float[Array, " n_dim n_dim"]] = None, + ): self.solver = solver self.path = path + n_features = self.n_features + if data_mean is not None: + self._data_mean = data_mean + else: + self._data_mean = jnp.zeros(n_features) + + if data_cov is not None: + self._data_cov = data_cov + else: + self._data_cov = jnp.eye(n_features) def sample(self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1) -> Float[Array, " n_dim"]: rng_key, subkey = jax.random.split(rng_key) samples = self.solver.sample(subkey, num_samples, dt=dt) + std = jnp.sqrt(jnp.diag(self.data_cov)) + samples = samples * std + self.data_mean return samples def log_prob(self, x: Float[Array, " n_dim"]) -> Float: - return self.solver.log_prob(x) + std = jnp.sqrt(jnp.diag(self.data_cov)) + x_whitened = (x - self.data_mean) / std + log_det = -jnp.sum(jnp.log(std)) + return self.solver.log_prob(x_whitened) + log_det def save_model(self, path: str): eqx.tree_serialise_leaves(path + ".eqx", self) @@ -160,8 +196,9 @@ def train_epoch( """Train for a single epoch.""" value = 1e9 model = self - train_ds_size = len(data) + train_ds_size = len(data[0]) steps_per_epoch = train_ds_size // batch_size + std = jnp.sqrt(jnp.diag(self.data_cov)) if steps_per_epoch > 0: perms = jax.random.permutation(rng, train_ds_size) @@ -169,12 +206,14 @@ def train_epoch( perms = perms.reshape((steps_per_epoch, batch_size)) for perm in perms: batch_x0, batch_x1, batch_t = data[0][perm, ...], data[1][perm, ...], data[2][perm, ...] + batch_x1 = (batch_x1 - self.data_mean) / std batch_x_t, batch_dx_t = self.path.sample(batch_x0, batch_x1, batch_t) value, model, state = model.train_step( batch_x_t, batch_t, batch_dx_t, optim, state ) else: - x_t, dx_t = self.path.sample(data[0], data[1], data[2]) + batch_x1 = (data[1] - self.data_mean) / std + x_t, dx_t = self.path.sample(data[0], batch_x1, data[2]) value, model, state = model.train_step( x_t, data[2], dx_t, optim, state ) @@ -214,8 +253,8 @@ def train( best_model = model = self best_state = state best_loss = 1e9 - # model = eqx.tree_at(lambda m: m._data_mean, model, jnp.mean(data, axis=0)) - # model = eqx.tree_at(lambda m: m._data_cov, model, jnp.cov(data.T)) + model = eqx.tree_at(lambda m: m._data_mean, model, jnp.mean(data[1], axis=0)) + model = eqx.tree_at(lambda m: m._data_cov, model, jnp.cov(data[1].T)) for epoch in pbar: # Use a separate PRNG key to permute image data during shuffling rng, input_rng = jax.random.split(rng) From 211d1f91a29f030442c1817fa091fff057463617 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 17 Aug 2025 08:26:52 -0400 Subject: [PATCH 2/2] Refactor flow matching code for readability and add diffrax to pre-commit config --- .pre-commit-config.yaml | 2 +- docs/tutorials/train_flow_match.ipynb | 34 ++++-- src/flowMC/resource/model/common.py | 1 + .../resource/model/flowmatching/base.py | 103 ++++++++++++------ src/flowMC/strategy/sequential_monte_carlo.py | 7 +- 5 files changed, 100 insertions(+), 47 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 14e2809a..f9ef6417 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: rev: v1.1.396 hooks: - id: pyright - additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, typing_extensions, equinox, optax, tqdm] + additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, typing_extensions, equinox, optax, tqdm, diffrax] - repo: https://github.com/nbQA-dev/nbQA rev: 1.9.1 hooks: diff --git a/docs/tutorials/train_flow_match.ipynb b/docs/tutorials/train_flow_match.ipynb index 2befb793..fa126d6b 100644 --- a/docs/tutorials/train_flow_match.ipynb +++ b/docs/tutorials/train_flow_match.ipynb @@ -13,9 +13,14 @@ "import equinox as eqx # Equinox\n", "\n", "from flowMC.resource.model.common import MLP\n", - "from flowMC.resource.model.flowmatching.base import Solver, Path, CondOTScheduler, FlowMatchingModel\n", + "from flowMC.resource.model.flowmatching.base import (\n", + " Solver,\n", + " Path,\n", + " CondOTScheduler,\n", + " FlowMatchingModel,\n", + ")\n", "from sklearn.datasets import make_moons\n", - "import matplotlib.pyplot as plt\n" + "import matplotlib.pyplot as plt" ] }, { @@ -41,7 +46,9 @@ "metadata": {}, "outputs": [], "source": [ - "solver = Solver(MLP([3,128, 128, 128,2], jax.random.PRNGKey(0), activation=jax.nn.swish))\n", + "solver = Solver(\n", + " MLP([3, 128, 128, 128, 2], jax.random.PRNGKey(0), activation=jax.nn.swish)\n", + ")\n", "path = Path(CondOTScheduler())\n", "model = FlowMatchingModel(solver, path)" ] @@ -109,13 +116,14 @@ } ], "source": [ - "\n", "key = jax.random.PRNGKey(seed)\n", "\n", "key, subkey = jax.random.split(key)\n", "x0 = jax.random.normal(subkey, (data.shape[0], 2)) # Initial points\n", "key, subkey = jax.random.split(key)\n", - "t = jax.random.uniform(subkey, (data.shape[0], 1), minval=0.0, maxval=1.0) # Random time points\n", + "t = jax.random.uniform(\n", + " subkey, (data.shape[0], 1), minval=0.0, maxval=1.0\n", + ") # Random time points\n", "\n", "optim = optax.adam(learning_rate)\n", "state = optim.init(eqx.filter(model, eqx.is_array))\n", @@ -144,7 +152,9 @@ ], "source": [ "sampled_data = trained_model.sample(key, 10000, dt=0.1)\n", - "plt.scatter(sampled_data[:, 0], sampled_data[:, 1], s=0.5, alpha=0.5, label=\"sampled data\")\n", + "plt.scatter(\n", + " sampled_data[:, 0], sampled_data[:, 1], s=0.5, alpha=0.5, label=\"sampled data\"\n", + ")\n", "plt.legend()\n", "plt.show()" ] @@ -170,11 +180,13 @@ "grid = jnp.mgrid[-2:2:100j, -2:2:100j]\n", "grid = grid.reshape(2, -1).T # Reshape to (10000, 2)\n", "log_prob = eqx.filter_vmap(trained_model.log_prob, in_axes=(0,))(grid)\n", - "plt.imshow(log_prob.reshape(100, 100).T, extent=(-2, 2, -2, 2), origin='lower', cmap='viridis')\n", - "plt.colorbar(label='Log Probability')\n", - "plt.title('Log Probability Density')\n", - "plt.xlabel('x')\n", - "plt.ylabel('y')\n", + "plt.imshow(\n", + " log_prob.reshape(100, 100).T, extent=(-2, 2, -2, 2), origin=\"lower\", cmap=\"viridis\"\n", + ")\n", + "plt.colorbar(label=\"Log Probability\")\n", + "plt.title(\"Log Probability Density\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"y\")\n", "plt.show()" ] } diff --git a/src/flowMC/resource/model/common.py b/src/flowMC/resource/model/common.py index 414bbcb3..8de695d3 100644 --- a/src/flowMC/resource/model/common.py +++ b/src/flowMC/resource/model/common.py @@ -6,6 +6,7 @@ from jaxtyping import Array, Float, PRNGKeyArray from abc import abstractmethod + class Bijection(eqx.Module): """Base class for bijective transformations. diff --git a/src/flowMC/resource/model/flowmatching/base.py b/src/flowMC/resource/model/flowmatching/base.py index fd7607e9..00620b37 100644 --- a/src/flowMC/resource/model/flowmatching/base.py +++ b/src/flowMC/resource/model/flowmatching/base.py @@ -1,6 +1,5 @@ import equinox as eqx from jaxtyping import PRNGKeyArray, Float, Array, PyTree -from numpy import diff import optax from flowMC.resource.base import Resource from flowMC.resource.model.common import MLP @@ -12,27 +11,34 @@ from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver from tqdm import trange, tqdm + class Solver(eqx.Module): - model: MLP # Shape should be [input_dim + t_dim, hiddens, output_dim] + model: MLP # Shape should be [input_dim + t_dim, hiddens, output_dim] method: AbstractSolver def __init__(self, model: MLP, method: str = "dopri5"): self.model = model self.method = Dopri5() - def sample(self, rng_key: PRNGKeyArray, n_samples: int, dt: Float = 1e-1) -> Float[Array, "n_samples n_dim"]: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int, dt: Float = 1e-1 + ) -> Float[Array, "n_samples n_dims"]: """Sample points from the solver. This sovles the ODE forward, i.e. from the prior to the posterior. """ - def model_wrapper(t: Float, x: Float[Array, "n_dims"], args: PyTree) -> Float[Array, "n_dim"]: + def model_wrapper( + t: Float, x: Float[Array, " n_dims"], args: PyTree + ) -> Float[Array, " n_dims"]: """Wrapper for the model to be used in the ODE solver.""" t = jnp.expand_dims(t, axis=-1) x = jnp.concatenate([x, t], axis=-1) return self.model(x) - def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float[Array, " n_dims"]: + def solve_ode( + y0: Float[Array, " n_dims"], dt: Float = 1e-1 + ) -> Float[Array, " n_dims"]: """Solve the ODE with initial condition y0.""" term = ODETerm(model_wrapper) sol = diffeqsolve( @@ -43,26 +49,31 @@ def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float[Array, " n dt0=dt, y0=y0, ) - return sol.ys[-1] # type: ignore - - x0 = jax.random.normal(rng_key, (n_samples, self.model.n_input-1)) + return sol.ys[-1] # type: ignore + + x0 = jax.random.normal(rng_key, (n_samples, self.model.n_input - 1)) sols = eqx.filter_vmap(solve_ode, in_axes=(0, None))(x0, dt) return sols - def log_prob(self, x1: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float: """Compute the log probability of the initial condition x1. This solves the ODE backward, i.e. from the posterior to the prior. """ - def model_wrapper(t: Float, x: Float[Array, "n_dims"], args: PyTree) -> tuple[Float[Array, "n_dim"], Float[Array, "1"]]: - """Wrapper for the model to be used in the ODE solver.""" + + def model_wrapper( + t: Float, x: Float[Array, " n_dims"], args: PyTree + ) -> list[Float[Array, " ,,,"]]: + """Wrapper for the model to be used in the ODE solver. + + The output shape should be [n_dims, 1]. + """ t = jnp.expand_dims(t, axis=-1) x = jnp.concatenate([x[0], t], axis=-1) y = self.model(x) div = jax.jacrev(self.model, argnums=0)(x)[:, :-1] return [y, jnp.trace(div)] - - def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float[Array, " n_dims"]: + + def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree: """Solve the ODE with initial condition y0.""" term = ODETerm(model_wrapper) y_init = jax.tree.map(jnp.asarray, [y0, 0.0]) @@ -75,9 +86,17 @@ def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float[Array, " n y0=y_init, ) return sol.ys - + x0, log_p = solve_ode(x1, dt) - return logpdf(x1, mean=self.model.n_output * jnp.zeros(self.model.n_output), cov=jnp.eye(self.model.n_output)) + log_p + return ( + logpdf( + x1, + mean=self.model.n_output * jnp.zeros(self.model.n_output), + cov=jnp.eye(self.model.n_output), + ) + + log_p + ) + class Scheduler: @@ -85,28 +104,31 @@ def __call__(self, t: Float) -> tuple[Float, Float, Float, Float]: """Return the parameters of the scheduler at time t.""" raise NotImplementedError + class CondOTScheduler(Scheduler): """Conditional Optimal Transport Scheduler.""" def __call__(self, t: Float) -> tuple[Float, Float, Float, Float]: """Return the parameters of the scheduler at time t.""" # Implement the logic to compute alpha_t, d_alpha_t, sigma_t, d_sigma_t - return t, 1., 1. - t, -1. + return t, 1.0, 1.0 - t, -1.0 + class Path: scheduler: Scheduler - + def __init__(self, scheduler: Scheduler): self.scheduler = scheduler - def sample(self, x0: Float, x1: Float, t:Float) -> Float: + def sample(self, x0: Float, x1: Float, t: Float) -> Float: """Sample a point along the path between x0 and x1 at time t.""" alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t) x_t = sigma_t * x0 + alpha_t * x1 dx_t = d_sigma_t * x0 + d_alpha_t * x1 return x_t, dx_t + class FlowMatchingModel(eqx.Module, Resource): solver: Solver @@ -146,7 +168,9 @@ def __init__( else: self._data_cov = jnp.eye(n_features) - def sample(self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1) -> Float[Array, " n_dim"]: + def sample( + self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1 + ) -> Float[Array, " n_dim"]: rng_key, subkey = jax.random.split(rng_key) samples = self.solver.sample(subkey, num_samples, dt=dt) std = jnp.sqrt(jnp.diag(self.data_cov)) @@ -166,9 +190,16 @@ def load_model(self, path: str) -> Self: return eqx.tree_deserialise_leaves(path + ".eqx", self) @eqx.filter_value_and_grad - def loss_fn(self, x: Float[Array, "n_batch n_dim"], t:Float[Array, "n_batch 1"], dx_t: Float[Array, "n_batch n_dim"],) -> Float: + def loss_fn( + self, + x: Float[Array, "n_batch n_dim"], + t: Float[Array, "n_batch 1"], + dx_t: Float[Array, "n_batch n_dim"], + ) -> Float: x = jnp.concatenate([x, t], axis=-1) - return jnp.mean((eqx.filter_vmap(self.solver.model, in_axes=(0))(x) - dx_t) ** 2) + return jnp.mean( + (eqx.filter_vmap(self.solver.model, in_axes=(0))(x) - dx_t) ** 2 + ) @eqx.filter_jit def train_step( @@ -190,7 +221,11 @@ def train_epoch( rng: PRNGKeyArray, optim: optax.GradientTransformation, state: optax.OptState, - data: tuple[Float[Array, "n_example n_dim"], Float[Array, "n_example n_dim"], Float[Array, "n_example 1"]], + data: tuple[ + Float[Array, "n_example n_dim"], + Float[Array, "n_example n_dim"], + Float[Array, "n_example 1"], + ], batch_size: Float, ) -> tuple[Float, Self, optax.OptState]: """Train for a single epoch.""" @@ -205,7 +240,11 @@ def train_epoch( perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch perms = perms.reshape((steps_per_epoch, batch_size)) for perm in perms: - batch_x0, batch_x1, batch_t = data[0][perm, ...], data[1][perm, ...], data[2][perm, ...] + batch_x0, batch_x1, batch_t = ( + data[0][perm, ...], + data[1][perm, ...], + data[2][perm, ...], + ) batch_x1 = (batch_x1 - self.data_mean) / std batch_x_t, batch_dx_t = self.path.sample(batch_x0, batch_x1, batch_t) value, model, state = model.train_step( @@ -214,15 +253,17 @@ def train_epoch( else: batch_x1 = (data[1] - self.data_mean) / std x_t, dx_t = self.path.sample(data[0], batch_x1, data[2]) - value, model, state = model.train_step( - x_t, data[2], dx_t, optim, state - ) + value, model, state = model.train_step(x_t, data[2], dx_t, optim, state) return value, model, state def train( self: Self, rng: PRNGKeyArray, - data: tuple[Float[Array, "n_example n_dim"], Float[Array, "n_example n_dim"], Float[Array, "n_example 1"]], + data: tuple[ + Float[Array, "n_example n_dim"], + Float[Array, "n_example n_dim"], + Float[Array, "n_example 1"], + ], optim: optax.GradientTransformation, state: optax.OptState, num_epochs: int, @@ -277,11 +318,11 @@ def train( pbar.set_description(f"Training NF, current loss: {value:.3f}") return rng, best_model, best_state, loss_values - + save_resource = save_model load_resource = load_model def print_parameters(self): - raise NotImplementedError("print_parameters is not implemented for FlowMatchingModel") - - \ No newline at end of file + raise NotImplementedError( + "print_parameters is not implemented for FlowMatchingModel" + ) diff --git a/src/flowMC/strategy/sequential_monte_carlo.py b/src/flowMC/strategy/sequential_monte_carlo.py index 00a88104..6691af1d 100644 --- a/src/flowMC/strategy/sequential_monte_carlo.py +++ b/src/flowMC/strategy/sequential_monte_carlo.py @@ -1,12 +1,11 @@ -from flowMC.strategy.base import Strategy from flowMC.resource.base import Resource from jaxtyping import Array, Float, PRNGKeyArray -from typing import Callable + class SequentialMonteCarlo(Resource): def __init__(self): raise NotImplementedError - + def __call__( self, rng_key: PRNGKeyArray, @@ -18,4 +17,4 @@ def __call__( dict[str, Resource], Float[Array, "n_chains n_dim"], ]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError