Skip to content

ENH: add functional jit-compatible API #9

Open
mj-will wants to merge 10 commits into
mainfrom
support-jax-jit
Open

ENH: add functional jit-compatible API #9
mj-will wants to merge 10 commits into
mainfrom
support-jax-jit

Conversation

@mj-will

@mj-will mj-will commented May 28, 2026

Copy link
Copy Markdown
Owner

Adds a functional, jit-compatible API.

This this required a rewrite of large bits of the code and all make orng and array-api-compat a hard dependency

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a functional, JIT-compatible sampling API by switching the sampler/step implementation to thread explicit RNG state via orng functional backends, and by making Array API support a first-class dependency.

Changes:

  • Introduces Sampler.sample_functional(...) and refactors the sampling loop to support functional RNG state (for JAX jit workflows).
  • Refactors step implementations (PCNStep, TPCNStep) to be functional (init_state/propose/adapt) and RNG-backend-driven.
  • Updates tests, CI dependencies, and project dependencies to rely on orng + array-api-compat.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
tests/test_sampler.py Adds unit tests for RNG resolution and disabling adaptation.
tests/integration_tests/test_sampling.py Adds integration coverage for functional API and JAX jit usage.
tests/conftest.py Switches RNG fixture to orng.RandomGenerator.
src/minipcn/utils.py Makes array-api-compat mandatory; adds JAX tracer handling and history stacking helpers.
src/minipcn/step.py Refactors steps to a functional API with explicit step state and RNG backend usage.
src/minipcn/sampler.py Adds sample_functional, deprecates constructor RNG, and reworks sampling loop to be functional/JIT-friendly.
README.md Updates usage docs and adds functional API section (but currently contains broken examples/installation text).
pyproject.toml Promotes array-api-compat and orng to core deps; adds jax/torch extras.
.github/workflows/test.yml Installs new extras in CI and removes standalone JAX install step.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread README.md Outdated
Comment on lines 40 to 43
## Array API support

`minipcn` also supports different array API backends via `array-api-compat`
and `orng` for random number generation. These can be installed by running
Comment thread README.md Outdated

# Run the sampler
chain, history = sampler.run(x0, n_steps=500)
chain, history = sampler.run(x0, n_steps=500, rng=rng)
Comment thread src/minipcn/sampler.py
Comment on lines 45 to 49
def __init__(
self,
log_prob_fn: Callable,
step_fn: Union[Step, str],
rng: "ArrayRNG" | np.random.Generator,
step_fn: str,
dims: int,
Comment thread src/minipcn/sampler.py Outdated
x_init: Array,
n_steps: int,
*,
rng: RandomGenerator | np.random.Generator = None,
Comment thread src/minipcn/step.py
Comment on lines +84 to +92
def __init__(
self,
dims: int,
xp: Any,
rng_backend: Any,
rho: float = 0.5,
adaptive: bool = True,
):
super().__init__(dims, xp, rng_backend)
Comment thread src/minipcn/step.py
Comment on lines +214 to +224
def init_state(self, x: Array) -> StepState:
from .utils import fit_student_t_em

self.mu, self.cov, self.nu = fit_student_t_em(x)
mu, cov, nu = fit_student_t_em(x)
if self.dims == 1:
self.inv_cov = self.xp.atleast_2d(1.0 / self.cov)
self.chol_cov = self.xp.atleast_2d(self.xp.sqrt(self.cov))
inv_cov = self.xp.atleast_2d(1.0 / cov)
chol_cov = self.xp.atleast_2d(self.xp.sqrt(cov))
else:
self.inv_cov = self.xp.linalg.inv(self.cov)
self.chol_cov = self.xp.linalg.cholesky(self.cov)
inv_cov = self.xp.linalg.inv(cov)
chol_cov = self.xp.linalg.cholesky(cov)
rho = self.xp.asarray(self.rho, dtype=x.dtype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants