prevent linear_solve calling jvp of solver.init#212
Merged
patrick-kidger merged 1 commit intopatrick-kidger:mainfrom Mar 18, 2026
Merged
prevent linear_solve calling jvp of solver.init#212patrick-kidger merged 1 commit intopatrick-kidger:mainfrom
patrick-kidger merged 1 commit intopatrick-kidger:mainfrom
Conversation
Contributor
Author
Contributor
adconner
reviewed
Mar 9, 2026
|
|
||
| # Differentiating through operator only, but options has a dynamic array. | ||
| # solver.init should not be differentiated through. | ||
| jax.jvp(f, (m,), (mt,)) |
Contributor
There was a problem hiding this comment.
Maybe for completeness we could test the backward pass too:
_, f_vjp = jax.vjp(f, m)
f_vjp(vt)where earlier line 233 maybe
vt = jax.random.normal(getkey(), (3,))
Owner
|
@adconner's comment aside, this LGTM! |
add vjp test
ab774a7 to
d226590
Compare
Owner
|
Alright, merged! Thank you for this 🎉 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
fixes #211, previously stop_gradient was called AFTER solver.init. This meant the jvp of solver.init was still called before the return tangent being set to zero (the primal returned by the jvp is retained). If we call stop_gradient BEFORE then the jvp is never called. This means we can use primitives without jvp's (e.g. geqp3/geqrf) or with incorrect jvp primal's (e.g. qr with
pivoting=True).Note that we don't need to apply
stop_gradientto options as we do not support taking a gradient with respect to e.g. precondtioners (we just get a nondiff error).If we're happy with this I will modify my invert PR #206 to mirror the same pattern.