ENH: add functional jit-compatible API #9
Open
mj-will wants to merge 10 commits into
Open
Conversation
There was a problem hiding this comment.
Pull request overview
Adds a functional, JIT-compatible sampling API by switching the sampler/step implementation to thread explicit RNG state via orng functional backends, and by making Array API support a first-class dependency.
Changes:
- Introduces
Sampler.sample_functional(...)and refactors the sampling loop to support functional RNG state (for JAXjitworkflows). - Refactors step implementations (
PCNStep,TPCNStep) to be functional (init_state/propose/adapt) and RNG-backend-driven. - Updates tests, CI dependencies, and project dependencies to rely on
orng+array-api-compat.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_sampler.py | Adds unit tests for RNG resolution and disabling adaptation. |
| tests/integration_tests/test_sampling.py | Adds integration coverage for functional API and JAX jit usage. |
| tests/conftest.py | Switches RNG fixture to orng.RandomGenerator. |
| src/minipcn/utils.py | Makes array-api-compat mandatory; adds JAX tracer handling and history stacking helpers. |
| src/minipcn/step.py | Refactors steps to a functional API with explicit step state and RNG backend usage. |
| src/minipcn/sampler.py | Adds sample_functional, deprecates constructor RNG, and reworks sampling loop to be functional/JIT-friendly. |
| README.md | Updates usage docs and adds functional API section (but currently contains broken examples/installation text). |
| pyproject.toml | Promotes array-api-compat and orng to core deps; adds jax/torch extras. |
| .github/workflows/test.yml | Installs new extras in CI and removes standalone JAX install step. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
40
to
43
| ## Array API support | ||
|
|
||
| `minipcn` also supports different array API backends via `array-api-compat` | ||
| and `orng` for random number generation. These can be installed by running |
|
|
||
| # Run the sampler | ||
| chain, history = sampler.run(x0, n_steps=500) | ||
| chain, history = sampler.run(x0, n_steps=500, rng=rng) |
Comment on lines
45
to
49
| def __init__( | ||
| self, | ||
| log_prob_fn: Callable, | ||
| step_fn: Union[Step, str], | ||
| rng: "ArrayRNG" | np.random.Generator, | ||
| step_fn: str, | ||
| dims: int, |
| x_init: Array, | ||
| n_steps: int, | ||
| *, | ||
| rng: RandomGenerator | np.random.Generator = None, |
Comment on lines
+84
to
+92
| def __init__( | ||
| self, | ||
| dims: int, | ||
| xp: Any, | ||
| rng_backend: Any, | ||
| rho: float = 0.5, | ||
| adaptive: bool = True, | ||
| ): | ||
| super().__init__(dims, xp, rng_backend) |
Comment on lines
+214
to
+224
| def init_state(self, x: Array) -> StepState: | ||
| from .utils import fit_student_t_em | ||
|
|
||
| self.mu, self.cov, self.nu = fit_student_t_em(x) | ||
| mu, cov, nu = fit_student_t_em(x) | ||
| if self.dims == 1: | ||
| self.inv_cov = self.xp.atleast_2d(1.0 / self.cov) | ||
| self.chol_cov = self.xp.atleast_2d(self.xp.sqrt(self.cov)) | ||
| inv_cov = self.xp.atleast_2d(1.0 / cov) | ||
| chol_cov = self.xp.atleast_2d(self.xp.sqrt(cov)) | ||
| else: | ||
| self.inv_cov = self.xp.linalg.inv(self.cov) | ||
| self.chol_cov = self.xp.linalg.cholesky(self.cov) | ||
| inv_cov = self.xp.linalg.inv(cov) | ||
| chol_cov = self.xp.linalg.cholesky(cov) | ||
| rho = self.xp.asarray(self.rho, dtype=x.dtype) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a functional, jit-compatible API.
This this required a rewrite of large bits of the code and all make orng and array-api-compat a hard dependency