Hi Patrick,
I wanted to experiment with some linear solvers from PETSc, and I managed to do so by subclassing `lx.AbstractLinearSolver` and using `jax.pure_callback`. This ensured that all of JAX's transformations work very well.
However, this places the computation on the host, even when the external library supports GPUs. Do you know if it would be possible to bypass this restriction somehow?
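For context, here is a minimal sketch of the callback pattern, assuming `numpy.linalg.solve` as a stand-in for the PETSc solve (the names `_host_solve` and `callback_solve` are hypothetical). It leaves out the `lx.AbstractLinearSolver` subclass and only shows where the host round-trip happens:

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_solve(matrix, vector):
    # Hypothetical stand-in for a PETSc solve; runs on the host with NumPy arrays.
    return np.linalg.solve(np.asarray(matrix), np.asarray(vector)).astype(vector.dtype)


@jax.jit
def callback_solve(matrix, vector):
    # JAX needs the result's shape and dtype up front to trace through the callback.
    out_spec = jax.ShapeDtypeStruct(vector.shape, vector.dtype)
    # pure_callback copies the inputs to host memory, calls _host_solve there,
    # and copies the result back, even if the inputs started out on a GPU.
    return jax.pure_callback(_host_solve, out_spec, matrix, vector)


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = callback_solve(A, b)
```

The device-to-host copy inside `jax.pure_callback` is the restriction in question: the callback always executes as Python on the host, regardless of where the surrounding computation runs.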