Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions src/aspire/flows/jax/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,26 @@


class FlowJax(Flow):
"""Flow implementation using FlowJax.

Parameters
----------
dims: int
Dimensionality of the flow
key: jax.random.KeyArray, optional
Random key for the flow. If None, a random key will be generated with
seed 0.
data_transform: BaseTransform, optional
Transform to apply to the data before fitting the flow. If None, no
transform will be applied.
dtype: jax.numpy.dtype, optional
Data type for the flow parameters and computations. If None, defaults to
jnp.float32.
**kwargs:
Additional keyword arguments to pass to the flow constructor. See the
flowjax documentation for details.
"""

xp = jnp

def __init__(
Expand Down Expand Up @@ -58,6 +78,20 @@ def __init__(
)

def fit(self, x, **kwargs):
"""Fit the flow to data.

This method calls :code:`fit_to_data` from flowjax.
See the flowjax documentation for details on the available fitting
options: https://danielward27.github.io/flowjax/api/training.html

Parameters
----------
x: array-like
Data to fit the flow to.
**kwargs: dict
Additional keyword arguments to pass to :code:`fit_to_data`. See
the flowjax documentation for details.
"""
from ...history import FlowHistory

x = jnp.asarray(x, dtype=self.dtype)
Expand All @@ -70,6 +104,22 @@ def fit(self, x, **kwargs):
)

def forward(self, x, xp: Callable = jnp):
"""Apply the flow transformation to the samples.

Parameters
----------
x: array-like
Samples to transform
xp: Callable, optional
Array library to use for the output. Default is jax.numpy.

Returns
-------
z: array-like
Transformed samples
log_abs_det_jacobian: array-like
Log absolute determinant of the Jacobian of the transformation
"""
x = jnp.asarray(x, dtype=self.dtype)
x_prime, log_abs_det_jacobian = self.rescale(x)
x_prime = jnp.asarray(x_prime, dtype=self.dtype)
Expand All @@ -79,6 +129,22 @@ def forward(self, x, xp: Callable = jnp):
)

def inverse(self, z, xp: Callable = jnp):
"""Apply the inverse flow transformation to the samples.

Parameters
----------
z: array-like
Samples to transform
xp: Callable, optional
Array library to use for the output. Default is jax.numpy.

Returns
-------
x: array-like
Transformed samples
log_abs_det_jacobian: array-like
Log absolute determinant of the Jacobian of the transformation
"""
z = jnp.asarray(z, dtype=self.dtype)
x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
x_prime = jnp.asarray(x_prime, dtype=self.dtype)
Expand All @@ -88,26 +154,79 @@ def inverse(self, z, xp: Callable = jnp):
)

def log_prob(self, x, xp: Callable = jnp):
"""Compute the log probability of the samples.

Parameters
----------
x: array-like
Samples to compute the log probability for
xp: Callable, optional
Array library to use for the output. Default is jax.numpy.

Returns
-------
log_prob: array-like
Log probability of the samples
"""
x = jnp.asarray(x, dtype=self.dtype)
x_prime, log_abs_det_jacobian = self.rescale(x)
x_prime = jnp.asarray(x_prime, dtype=self.dtype)
log_prob = self._flow.log_prob(x_prime)
return xp.asarray(log_prob + log_abs_det_jacobian)

def sample(self, n_samples: int, xp: Callable = jnp):
"""Generate samples from the flow.

Parameters
----------
n_samples: int
Number of samples to generate
xp: Callable, optional
Array library to use for the output. Default is jax.numpy.

Returns
-------
x: array-like
Generated samples
"""
self.key, subkey = jrandom.split(self.key)
x_prime = self._flow.sample(subkey, (n_samples,))
x = self.inverse_rescale(x_prime)[0]
return xp.asarray(x)

def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
"""Generate samples from the flow and compute their log probability.

Parameters ----------
n_samples: int
Number of samples to generate
xp: Callable, optional
Array library to use for the output. Default is jax.numpy.

Returns
-------
x: array-like
Generated samples
log_prob: array-like
Log probability of the generated samples
"""
self.key, subkey = jrandom.split(self.key)
x_prime = self._flow.sample(subkey, (n_samples,))
log_prob = self._flow.log_prob(x_prime)
x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)

def save(self, h5_file, path="flow"):
"""Save the flow to an HDF5 file.

Parameters
----------
h5_file: h5py.File
The HDF5 file to save to. The file should be opened in a mode that
allows writing.
path: str, optional
The path within the HDF5 file to save to. Default is "flow".
"""
import equinox as eqx
from array_api_compat import numpy as np

Expand Down Expand Up @@ -146,6 +265,20 @@ def save(self, h5_file, path="flow"):

@classmethod
def load(cls, h5_file, path="flow"):
"""Load a flow from an HDF5 file.

Parameters
----------
h5_file: h5py.File
The HDF5 file to load from.
path: str, optional
The path within the HDF5 file to load from. Default is "flow".

Returns
-------
FlowJax
The loaded flow object.
"""
import equinox as eqx

from ...utils import load_from_h5_file
Expand Down
Loading
Loading