JAXNRSur

A JAX-based package for differentiable numerical relativity surrogate waveform generation

JAXNRSur is a JAX-based package for differentiable evaluation of numerical relativity (NR) surrogate waveforms. By reimplementing the surrogate pipeline in JAX, it delivers NR-faithful time-domain waveforms for precessing and high-mass-ratio binary black holes, with native GPU support and automatic differentiation. This enables gradient-based inference within modern gravitational-wave pipelines such as Jim.
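
Automatic differentiation pays off in practice because a likelihood built on a JAX waveform can be differentiated with jax.grad and compiled with jax.jit. The sketch below shows this pattern on a toy damped sinusoid. The toy is a hypothetical stand-in, not the JAXNRSur surrogate or its API. In a real analysis, the toy function would be replaced by a surrogate evaluation.

```python
import jax
import jax.numpy as jnp


def toy_waveform(params, t):
    """Hypothetical stand-in for a surrogate: a damped sinusoid h(t).

    params = (amplitude, frequency in Hz, damping time in s). This is NOT
    the JAXNRSur surrogate; it only illustrates differentiating through a
    waveform model written in JAX.
    """
    amp, freq, tau = params
    return amp * jnp.exp(-t / tau) * jnp.sin(2.0 * jnp.pi * freq * t)


def loss(params, t, data):
    """Least-squares residual between model and data (a stand-in for a
    Gaussian log-likelihood, up to sign and constants)."""
    residual = toy_waveform(params, t) - data
    return jnp.sum(residual**2)


# Synthetic "data" generated from known parameters
t = jnp.linspace(0.0, 1.0, 1024)
true_params = jnp.array([1.0, 5.0, 0.3])
data = toy_waveform(true_params, t)

# Reverse-mode gradient with respect to params, compiled with XLA
grad_loss = jax.jit(jax.grad(loss))

trial_params = jnp.array([0.9, 5.0, 0.3])
g = grad_loss(trial_params, t, data)
print("d(loss)/d(params) =", g)
```

At the true parameters the residual vanishes, so the gradient is zero up to floating-point rounding. Away from the true parameters, the gradient gives the direction in which the loss increases. Gradient-based samplers and optimizers consume exactly this quantity.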

Supported models (q = m₁/m₂ ≥ 1 is the mass ratio):

  • NRHybSur3dq8 — aligned-spin hybrid surrogate (q ≤ 8)
  • NRSur7dq4 — precessing NR surrogate (q ≤ 4)
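
The mass-ratio ranges listed above can be encoded in a small helper that picks a model for a given binary. This is a hypothetical sketch, not a JAXNRSur function, and it rests on several assumptions:

  • It uses the convention q = m1/m2 ≥ 1.
  • It treats any nonzero in-plane spin component as precessing.
  • It prefers NRHybSur3dq8 for aligned-spin systems, on the assumption that a hybrid surrogate covers a longer inspiral.
  • It checks only the mass ratio. Other validity limits, such as those on spin magnitudes, are not encoded.

```python
from typing import Sequence

# Mass-ratio limits as listed in this README; q = m1/m2 >= 1 is assumed.
MODEL_MAX_Q = {
    "NRHybSur3dq8": 8.0,  # aligned-spin hybrid surrogate
    "NRSur7dq4": 4.0,     # precessing NR surrogate
}


def choose_model(q: float, chi1: Sequence[float], chi2: Sequence[float]) -> str:
    """Hypothetical helper: return the name of a supported model for a binary.

    chi1 and chi2 are dimensionless spin vectors (x, y, z). Any nonzero
    in-plane (x, y) component is treated as a precessing system.
    """
    if q < 1.0:
        raise ValueError("use the convention q = m1/m2 >= 1")
    precessing = any(abs(c) > 0.0 for c in (*chi1[:2], *chi2[:2]))
    if precessing:
        # Only the precessing surrogate handles in-plane spins
        if q <= MODEL_MAX_Q["NRSur7dq4"]:
            return "NRSur7dq4"
        raise ValueError(f"precessing system with q={q} exceeds NRSur7dq4 range (q <= 4)")
    # Aligned spins: prefer the hybrid surrogate (design choice, see lead-in)
    if q <= MODEL_MAX_Q["NRHybSur3dq8"]:
        return "NRHybSur3dq8"
    raise ValueError(f"q={q} exceeds NRHybSur3dq8 range (q <= 8)")
```

For example, choose_model(2.0, (0.0, 0.0, 0.3), (0.0, 0.0, -0.2)) returns "NRHybSur3dq8". Adding an in-plane component to chi1 at the same q, as in (0.1, 0.0, 0.3), returns "NRSur7dq4".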

For a quick introduction, see the Quick Start guide.

Installation

The simplest way to install JAXNRSur is through pip:

pip install JAXNRSur

This installs the latest stable release and its dependencies. JAXNRSur is built on JAX, and the default install pulls in the CPU-only version of JAX. If you have an NVIDIA GPU, install the CUDA-enabled version instead:

pip install "JAXNRSur[cuda]"
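
After installing, you can check which backend JAX is using. The following calls belong to JAX itself, not JAXNRSur, and print the default backend and the visible devices. With the CPU-only install this should report cpu. With the CUDA-enabled install on a machine with a working NVIDIA driver, it should report gpu.

```python
import jax

# Default backend JAX selected: "cpu" for the CPU-only install,
# "gpu" if the CUDA-enabled build found an NVIDIA device.
backend = jax.default_backend()
print("JAX backend:", backend)

# List every device JAX can place computations on
for device in jax.devices():
    print(device)
```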

To install the latest development version of JAXNRSur, clone this repository and install it locally in editable mode:

git clone https://github.com/GW-JAX-Team/JAXNRSur.git
cd JAXNRSur
pip install -e .

We recommend using uv to manage your Python environment. After cloning the repository, run uv sync to create a virtual environment with all dependencies installed.

Attribution

If you use JAXNRSur in your research, please cite the underlying surrogate models:
