
prevent linear_solve calling jvp of solver.init #212

Merged
patrick-kidger merged 1 commit into patrick-kidger:main from jpbrodrick89:jpb/fix-stop-gradient
Mar 18, 2026

@jpbrodrick89
Contributor

@jpbrodrick89 jpbrodrick89 commented Mar 9, 2026

Fixes #211. Previously, stop_gradient was called AFTER solver.init. This meant the JVP of solver.init was still evaluated, and only afterwards was its output tangent set to zero (the primal returned by the JVP was retained). If we call stop_gradient BEFORE solver.init, the JVP is never called. This means we can use primitives that have no JVP rule (e.g. geqp3/geqrf) or whose JVP rule returns an incorrect primal (e.g. qr with pivoting=True).
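To illustrate the ordering issue outside of lineax, here is a minimal sketch. The function `init` is a hypothetical stand-in for solver.init (it is not lineax code), and its custom JVP rule records every time it is traced so we can see whether differentiation reaches it.

```python
import jax
import jax.numpy as jnp

calls = []  # filled in whenever the custom JVP rule below is traced


@jax.custom_jvp
def init(x):
    # Hypothetical stand-in for solver.init (e.g. a factorisation).
    return 2.0 * x


@init.defjvp
def init_jvp(primals, tangents):
    calls.append("jvp")  # Python-side record that the rule was reached
    (x,), (t,) = primals, tangents
    return init(x), 2.0 * t


def stop_after(x):
    # Old ordering: differentiate init, then discard the tangent.
    return jax.lax.stop_gradient(init(x))


def stop_before(x):
    # New ordering: init only ever sees a non-differentiated input.
    return init(jax.lax.stop_gradient(x))


x = jnp.arange(3.0)
t = jnp.ones(3)

calls.clear()
_, tan_after = jax.jvp(stop_after, (x,), (t,))
calls_after = list(calls)  # non-empty: the JVP rule was still traced

calls.clear()
_, tan_before = jax.jvp(stop_before, (x,), (t,))
calls_before = list(calls)  # empty: the JVP rule was never reached

print(calls_after, calls_before)
print(tan_after, tan_before)  # both tangents are zero either way
```

Both orderings give a zero output tangent; the difference is only whether the JVP rule is invoked at all, which is what matters when that rule is missing or wrong.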

Note that we don't need to apply stop_gradient to options, as we do not support taking a gradient with respect to e.g. preconditioners (we just get a nondiff error).

If we're happy with this I will modify my invert PR #206 to mirror the same pattern.

@jpbrodrick89
Contributor Author

@adconner lmk if this gets your #210 to pass tests.

@adconner adconner mentioned this pull request Mar 9, 2026
@adconner
Contributor

adconner commented Mar 9, 2026

> @adconner lmk if this gets your #210 to pass tests.

Yes fixed!


# Differentiating through operator only, but options has a dynamic array.
# solver.init should not be differentiated through.
jax.jvp(f, (m,), (mt,))
Contributor


Maybe for completeness we could test the backward pass too:

_, f_vjp = jax.vjp(f, m)
f_vjp(vt)

where vt would be defined earlier, maybe around line 233:

vt = jax.random.normal(getkey(), (3,))
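For reference, a self-contained sketch of what a combined forward and backward check could look like. This is not the test from the PR: `f` here uses `jnp.linalg.solve` as a stand-in for the lineax solve under test, and the matrix, right-hand side, and keys are made up for illustration. It checks that the JVP and VJP agree via the identity <vt, J mt> = <J^T vt, mt>.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs (assumptions, not taken from the test suite).
key_b, key_mt, key_vt = jax.random.split(jax.random.PRNGKey(0), 3)
m = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])  # fixed, well-conditioned matrix
b = jax.random.normal(key_b, (3,))
mt = jax.random.normal(key_mt, (3, 3))  # tangent for m
vt = jax.random.normal(key_vt, (3,))    # cotangent for the output


def f(m):
    # Stand-in for the linear solve being tested; differentiates w.r.t. m only.
    return jnp.linalg.solve(m, b)


# Forward mode: directional derivative of f at m along mt.
_, jvp_out = jax.jvp(f, (m,), (mt,))

# Reverse mode: pull the cotangent vt back to a cotangent for m.
_, f_vjp = jax.vjp(f, m)
(m_bar,) = f_vjp(vt)

# Consistency check between the two modes.
lhs = jnp.vdot(vt, jvp_out)
rhs = jnp.vdot(m_bar, mt)
print(lhs, rhs)
```

If both modes are implemented consistently, `lhs` and `rhs` should agree up to floating-point error.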

@patrick-kidger
Owner

@adconner's comment aside, this LGTM!

@patrick-kidger patrick-kidger force-pushed the jpb/fix-stop-gradient branch from ab774a7 to d226590 Compare March 17, 2026 23:30
@patrick-kidger patrick-kidger merged commit b6f3087 into patrick-kidger:main Mar 18, 2026
1 check passed
@patrick-kidger
Owner

Alright, merged! Thank you for this 🎉


Development

Successfully merging this pull request may close these issues.

Differentiation of linear_solve causes differentiation through factorization (solver state)

3 participants