There are currently no unit tests for `gradient_creator()`, and I don't trust the implementation for JAX: it calls `jax.grad()` on a `ParametrizedFunction`, not on the internal JAX function.
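Here is a minimal JAX sketch of why this matters. It assumes, hypothetically, that a `ParametrizedFunction` stores its parameters internally and is called with data only. `ToyParametrizedFunction` and `_gaussian` are illustrative stand-ins, not the library's API. `jax.grad` differentiates with respect to the *first positional argument*. Applied to such a wrapper, it therefore returns the derivative with respect to the data, not the parameters:

```python
import jax
import jax.numpy as jnp


# Hypothetical "internal" JAX function: parameters are passed explicitly
# as the first argument, the data as the second.
def _gaussian(params, x):
    return jnp.exp(-0.5 * ((x - params["mu"]) / params["sigma"]) ** 2)


class ToyParametrizedFunction:
    """Hypothetical stand-in for a ParametrizedFunction: it keeps its
    parameters internally and is called with data only."""

    def __init__(self, function, parameters):
        self._function = function
        self.parameters = dict(parameters)

    def __call__(self, x):
        return self._function(self.parameters, x)


func = ToyParametrizedFunction(_gaussian, {"mu": 0.0, "sigma": 1.0})
x = 1.0

# jax.grad on the wrapper differentiates w.r.t. its first argument: the DATA.
dfunc_dx = jax.grad(func)(x)

# Gradients w.r.t. the PARAMETERS need the internal function,
# with the parameters as its first argument.
dfunc_dparams = jax.grad(_gaussian)(func.parameters, x)
```

A unit test for `gradient_creator()` could compare its output against analytic derivatives like these. For this Gaussian at `x = 1`, `mu = 0`, `sigma = 1`, both parameter derivatives equal `exp(-0.5)` and the data derivative equals `-exp(-0.5)`.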
It would also be nice to illustrate the creation of a gradient for a function in a notebook. (E.g., a simple demo of error propagation with autodiff.)
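One possible starting point for such a notebook is plain JAX, independent of `gradient_creator()`. It uses first-order (linear) error propagation, where the variance of a scalar observable `f` is approximated by `J @ C @ J.T`. Here `J` is the gradient of `f` at the measured values and `C` is their covariance matrix. The `propagate_error` helper and the ratio observable below are hypothetical examples:

```python
import jax
import jax.numpy as jnp


def propagate_error(f, values, covariance):
    """Linear error propagation: var(f) ~= J @ C @ J.T, where J is the
    gradient of the scalar function f evaluated at the measured values."""
    jac = jax.grad(f)(values)
    variance = jac @ covariance @ jac
    return f(values), jnp.sqrt(variance)


# Example observable: the ratio of two measured quantities, f(a, b) = a / b.
def ratio(v):
    return v[0] / v[1]


values = jnp.array([10.0, 4.0])
sigmas = jnp.array([0.3, 0.2])
covariance = jnp.diag(sigmas**2)  # assumes uncorrelated measurements

f_value, f_error = propagate_error(ratio, values, covariance)
```

For this uncorrelated case, the result should match the textbook formula for a ratio, `f * sqrt((sa / a)**2 + (sb / b)**2)`. Correlations can be included by filling in the off-diagonal entries of `covariance`.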