diff --git a/src/aspire/flows/jax/flows.py b/src/aspire/flows/jax/flows.py index 4973e7b..a081301 100644 --- a/src/aspire/flows/jax/flows.py +++ b/src/aspire/flows/jax/flows.py @@ -15,6 +15,26 @@ class FlowJax(Flow): + """Flow implementation using FlowJax. + + Parameters + ---------- + dims: int + Dimensionality of the flow + key: jax.random.KeyArray, optional + Random key for the flow. If None, a random key will be generated with + seed 0. + data_transform: BaseTransform, optional + Transform to apply to the data before fitting the flow. If None, no + transform will be applied. + dtype: jax.numpy.dtype, optional + Data type for the flow parameters and computations. If None, defaults to + jnp.float32. + **kwargs: + Additional keyword arguments to pass to the flow constructor. See the + flowjax documentation for details. + """ + xp = jnp def __init__( @@ -58,6 +78,20 @@ def __init__( ) def fit(self, x, **kwargs): + """Fit the flow to data. + + This method calls :code:`fit_to_data` from flowjax. + See the flowjax documentation for details on the available fitting + options: https://danielward27.github.io/flowjax/api/training.html + + Parameters + ---------- + x: array-like + Data to fit the flow to. + **kwargs: dict + Additional keyword arguments to pass to :code:`fit_to_data`. See + the flowjax documentation for details. + """ from ...history import FlowHistory x = jnp.asarray(x, dtype=self.dtype) @@ -70,6 +104,22 @@ def fit(self, x, **kwargs): ) def forward(self, x, xp: Callable = jnp): + """Apply the flow transformation to the samples. + + Parameters + ---------- + x: array-like + Samples to transform + xp: Callable, optional + Array library to use for the output. Default is jax.numpy. + + Returns + ------- + z: array-like + Transformed samples + log_abs_det_jacobian: array-like + Log absolute determinant of the Jacobian of the transformation + """ x = jnp.asarray(x, dtype=self.dtype) x_prime, log_abs_det_jacobian = self.rescale(x) x_prime = jnp.asarray(x_prime, dtype=self.dtype) @@ -79,6 +129,22 @@ def forward(self, x, xp: Callable = jnp): ) def inverse(self, z, xp: Callable = jnp): + """Apply the inverse flow transformation to the samples. + + Parameters + ---------- + z: array-like + Samples to transform + xp: Callable, optional + Array library to use for the output. Default is jax.numpy. + + Returns + ------- + x: array-like + Transformed samples + log_abs_det_jacobian: array-like + Log absolute determinant of the Jacobian of the transformation + """ z = jnp.asarray(z, dtype=self.dtype) x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z) x_prime = jnp.asarray(x_prime, dtype=self.dtype) @@ -88,6 +154,20 @@ def inverse(self, z, xp: Callable = jnp): ) def log_prob(self, x, xp: Callable = jnp): + """Compute the log probability of the samples. + + Parameters + ---------- + x: array-like + Samples to compute the log probability for + xp: Callable, optional + Array library to use for the output. Default is jax.numpy. + + Returns + ------- + log_prob: array-like + Log probability of the samples + """ x = jnp.asarray(x, dtype=self.dtype) x_prime, log_abs_det_jacobian = self.rescale(x) x_prime = jnp.asarray(x_prime, dtype=self.dtype) @@ -95,12 +175,41 @@ def log_prob(self, x, xp: Callable = jnp): return xp.asarray(log_prob + log_abs_det_jacobian) def sample(self, n_samples: int, xp: Callable = jnp): + """Generate samples from the flow. + + Parameters + ---------- + n_samples: int + Number of samples to generate + xp: Callable, optional + Array library to use for the output. Default is jax.numpy. + + Returns + ------- + x: array-like + Generated samples + """ self.key, subkey = jrandom.split(self.key) x_prime = self._flow.sample(subkey, (n_samples,)) x = self.inverse_rescale(x_prime)[0] return xp.asarray(x) def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp): + """Generate samples from the flow and compute their log probability. + + Parameters ---------- + n_samples: int + Number of samples to generate + xp: Callable, optional + Array library to use for the output. Default is jax.numpy. + + Returns + ------- + x: array-like + Generated samples + log_prob: array-like + Log probability of the generated samples + """ self.key, subkey = jrandom.split(self.key) x_prime = self._flow.sample(subkey, (n_samples,)) log_prob = self._flow.log_prob(x_prime) @@ -108,6 +217,16 @@ def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp): return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian) def save(self, h5_file, path="flow"): + """Save the flow to an HDF5 file. + + Parameters + ---------- + h5_file: h5py.File + The HDF5 file to save to. The file should be opened in a mode that + allows writing. + path: str, optional + The path within the HDF5 file to save to. Default is "flow". + """ import equinox as eqx from array_api_compat import numpy as np @@ -146,6 +265,20 @@ def save(self, h5_file, path="flow"): @classmethod def load(cls, h5_file, path="flow"): + """Load a flow from an HDF5 file. + + Parameters + ---------- + h5_file: h5py.File + The HDF5 file to load from. + path: str, optional + The path within the HDF5 file to load from. Default is "flow". + + Returns + ------- + FlowJax + The loaded flow object. + """ import equinox as eqx from ...utils import load_from_h5_file diff --git a/src/aspire/flows/torch/flows.py b/src/aspire/flows/torch/flows.py index f8aee48..12c5025 100644 --- a/src/aspire/flows/torch/flows.py +++ b/src/aspire/flows/torch/flows.py @@ -111,6 +111,29 @@ def load(self, h5_file, path="flow"): class ZukoFlow(BaseTorchFlow): + """Flow wrapper for flows from the Zuko library + + Parameters + ---------- + dims: int + Dimensionality of the data + flow_class: str or Callable + The flow class to use. Can be a string (e.g. "MAF", "CNF") or a + callable that returns an instance of a flow. + data_transform: BaseTransform, optional + A transform to apply to the data before fitting the flow. If None, no + transform is applied. + seed: int + Random seed for initializing the flow + device: str + Device to run the flow on (e.g. "cpu", "cuda") + dtype: torch.dtype, optional + Data type for the flow parameters. If None, uses the default torch + dtype. + **kwargs: + Additional keyword arguments to pass to the flow constructor. + """ + def __init__( self, dims, @@ -153,7 +176,30 @@ def fit( validation_fraction: float = 0.2, clip_grad: float | None = None, lr_annealing: bool = False, + patience: int | None = None, ): + """Fit the flow to samples using maximum likelihood estimation (forward + KL divergence). + + Parameters + ---------- + x: array-like + Samples to fi the flow to + n_epochs: int + Number of epochs to train for + lr: float + Learning rate for the optimizer + batch_size: int + Batch size for training + validation_fraction: float + Fraction of the data to use for validation + clip_grad: float | None + If not None, clip gradients to this value + lr_annealing: bool + Whether to use cosine annealing for the learning rate + patience: int | None + If not None, use early stopping with this patience (in epochs) + """ from ...history import FlowHistory if not is_torch_array(x): @@ -225,9 +271,10 @@ def fit( best_val_loss = float("inf") best_flow_state = None + best_epoch = 0 with tqdm.tqdm(range(n_epochs), desc="Epochs") as pbar: - for _ in pbar: + for epoch in pbar: self.flow.train() loss_epoch = 0.0 for (x_batch,) in dataset: @@ -253,32 +300,86 @@ def fit( if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss best_flow_state = copy.deepcopy(self.flow.state_dict()) + best_epoch = epoch history.validation_loss.append(avg_val_loss) pbar.set_postfix( train_loss=f"{avg_train_loss:.4f}", val_loss=f"{avg_val_loss:.4f}", ) + + if patience is not None and epoch - best_epoch >= patience: + logger.info( + f"Early stopping triggered after {patience} epochs" + ) + break + if best_flow_state is not None: self.flow.load_state_dict(best_flow_state) - logger.info(f"Loaded best model with val loss {best_val_loss:.4f}") + logger.info( + f"Loaded best model from epoch {best_epoch} " + f"with val loss {best_val_loss:.4f}" + ) self.flow.eval() return history def sample_and_log_prob(self, n_samples: int, xp=torch_api): + """Generate samples from the flow and compute their log probability. + + Parameters ---------- + n_samples: int + Number of samples to generate + xp: Callable, optional + Array library to use for the output. Default is jax.numpy. + + Returns + ------- + x: array-like + Generated samples + log_prob: array-like + Log probability of the generated samples + """ with torch.no_grad(): x_prime, log_prob = self.flow().rsample_and_log_prob((n_samples,)) x, log_abs_det_jacobian = self.inverse_rescale(x_prime) return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian) def sample(self, n_samples: int, xp=torch_api): + """Generate samples from the flow. + + Parameters + ---------- + n_samples: int + Number of samples to generate + xp: Callable, optional + Array library to use for the output. Default is torch. + + Returns + ------- + x: array-like + Generated samples + """ with torch.no_grad(): x_prime = self.flow().rsample((n_samples,)) x = self.inverse_rescale(x_prime)[0] return xp.asarray(x) def log_prob(self, x, xp=torch_api): + """Compute the log probability of the samples. + + Parameters + ---------- + x: array-like + Samples to compute the log probability for + xp: Callable, optional + Array library to use for the output. Default is torch. + + Returns + ------- + log_prob: array-like + Log probability of the samples + """ x = torch.as_tensor(x, dtype=self.dtype, device=self.device) x_prime, log_abs_det_jacobian = self.rescale(x) return xp.asarray( @@ -286,6 +387,22 @@ def log_prob(self, x, xp=torch_api): ) def forward(self, x, xp=torch_api): + """Apply the flow transformation to the samples. + + Parameters + ---------- + x: array-like + Samples to transform + xp: Callable, optional + Array library to use for the output. Default is torch. + + Returns + ------- + z: array-like + Transformed samples + log_abs_det_jacobian: array-like + Log absolute determinant of the Jacobian of the transformation + """ x = torch.as_tensor(x, dtype=self.dtype, device=self.device) x_prime, log_j_rescale = self.rescale(x) z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime) @@ -297,6 +414,22 @@ def forward(self, x, xp=torch_api): return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale) def inverse(self, z, xp=torch_api): + """Apply the inverse flow transformation to the samples. + + Parameters + ---------- + z: array-like + Samples to transform + xp: Callable, optional + Array library to use for the output. Default is torch. + + Returns + ------- + x: array-like + Transformed samples + log_abs_det_jacobian: array-like + Log absolute determinant of the Jacobian of the transformation + """ z = torch.as_tensor(z, dtype=self.dtype, device=self.device) with torch.no_grad(): x_prime, log_abs_det_jacobian = ( @@ -312,6 +445,12 @@ def inverse(self, z, xp=torch_api): class ZukoFlowMatching(ZukoFlow): + """Flow wrapper for training Zuko flows using flow matching. + + Note that this flow is only compatible with CNF flows, as flow matching is + not implemented for discrete flows in Zuko. + """ + def __init__( self, dims,