Add invert helper function to wrap linear_solve in FunctionLinearOperator#206
jpbrodrick89 wants to merge 12 commits into patrick-kidger:main
Conversation
patrick-kidger
left a comment
I think in principle this looks pretty good to me!
I think my main (only?) concern is what will happen when we compose this into larger systems — like #196, where it turns out that we're playing whack-a-mole building our own compiler over operators. That's not really a blocker but we might want to be up-front about how far we intend to pursue such optimizations — as my preference would be 'not much', mostly for maintenance time reasons.
Thanks Patrick, appreciate the concern. Always happy to have a call if that's easier to discuss. The main general "optimisation" in the spirit of #196 I still have in mind is to use …

For inverse operators specifically, I think the two examples of an inverse chain and the Woodbury matrix identity are good ones to reason about to see how bad this could get (happy to submit draft PRs if you'd rather see a working example). Firstly though, I think we should limit our reasoning to direct solvers only; I don't envisage many cases where …

Let's start with Woodbury because it's a textbook case. We could either use @aidancrilly's #97 approach of providing each matrix directly, or extract them from an …:

```python
A_inv = lx.invert(A, solver=A_solver)
A_inv_U = jax.vmap(A_inv.mv)(lx.materialise(U).pytree)
A_inv_b = A_inv.mv(b)
S = lx.invert(C, solver=C_solver) + PyTreeLinearOperator(V.mv(A_inv_U), out_structure=...)
sol = (IdentityLinearOperator(...) - A_inv_U @ lx.invert(S) @ V)(A_inv_b)
```

Yes, I did oversimplify things and miss out on an optimisation by stacking U and b, but this seems pretty elegant to me and the opposite of hard to maintain (compared to a non-operator approach at least). Just like in #196, I don't think it makes any sense to try to overengineer by detecting whether it's more efficient to calculate …

Now the more controversial approach is the inverse chain, which would probably need to be advertised as experimental to begin with. Yes, I admit this does feel a bit like building a compiler over operators, but I think it could be really valuable. I think I'd take a really simple approach to begin with, where each matrix in a composed chain is inverted unless the inner dimensions are bigger than the outer, with no special handling of add operators to begin with (we could allow Woodbury to handle this in the future, but I admit that would get convoluted and hard to maintain). So …

Also, if you hate this but love the Woodbury, we could just do that for now and I can play around with the chain stuff outwith lineax. I have no deep thoughts/great insight on how to handle block matrices, and if it turns out that this is too hard to address in lineax, so be it; at least we're doing a best effort in providing the foundation of `lx.invert` in the first place, and maybe others can play around and work it out. 🙂

Another thing to make sure you're comfortable with is that we don't EXPLICITLY call …
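For reference, the operator sketch above is an instance of the Woodbury identity, (A + UCV)⁻¹b = A⁻¹b − A⁻¹U (C⁻¹ + V A⁻¹U)⁻¹ V A⁻¹b. A minimal NumPy check of that identity (plain arrays and made-up small matrices, not lineax code) looks like:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 5, 2
A = rng.normal(size=(n, n)) + n * np.eye(n)  # shifted to keep it well-conditioned
U = rng.normal(size=(n, k))
C = rng.normal(size=(k, k)) + k * np.eye(k)
V = rng.normal(size=(k, n))
b = rng.normal(size=n)

# Direct solve against the full (dense) matrix.
direct = np.linalg.solve(A + U @ C @ V, b)

# Woodbury: only solves against A plus a small k x k system S.
A_inv_U = np.linalg.solve(A, U)       # A^{-1} U
A_inv_b = np.linalg.solve(A, b)       # A^{-1} b
S = np.linalg.inv(C) + V @ A_inv_U    # C^{-1} + V A^{-1} U
woodbury = A_inv_b - A_inv_U @ np.linalg.solve(S, V @ A_inv_b)

assert np.allclose(direct, woodbury)
```

The point of the operator formulation is that the two solves against `A` share one factorisation, which is exactly what caching the state in `lx.invert` would buy.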
patrick-kidger
left a comment
Okay! Then in this case I think I'm pretty happy with this. Two outstanding minor nits (one new one + the cache one) and then we can merge this. I'm currently aiming to merge this and #164 and then do a new release.
lineax/_solve.py
!!! Warning

    Passing `state` to `linear_solve` does not support autodiff out of
    the box. If you need to differentiate through the solve, either
    wrap the state with `jax.lax.stop_gradient`, or use
    [`lineax.invert`][] (with `cache=True`) which handles this
    automatically.
I don't think 'state' is defined outside of the example code above. (I'm not sure I understand the warning here.)
For example, if you comment out line 371 of tests/helpers.py (`dynamic_state = lax.stop_gradient(dynamic_state)`), or the similar `stop_gradient` in the PR, you get a bunch of failures in test_jvp_jvp*.py and two failures in test_invert.py:

```
RuntimeError: Unexpected tangent. `lineax.linear_solve(..., state=...)` cannot be autodifferentiated.
```
My point is that with `lx.invert` you don't need to manually add `stop_gradient` if you want to re-use state. The warning is specifically located after the state re-use example for that reason. Do you think I should remove or reword?
Also see #127 which is slightly related (but different).
Ah, I see what you're getting at! This makes sense to me.
I'm actually not sure I really believe in our requirement that `state` be nondifferentiable. This was meant to enforce that the user would explicitly place a `lax.stop_gradient` in their own code, so that they would realise that there is no gradient there. But in retrospect I think that might be too niche/annoying.
So as an alternative, I've just opened #213 – WDYT? (I'm also conscious that will slightly impact #212.)
Not sure if you're still keen on having the helper in lineax main given the new stop-gradient handling, but if you are, one possible extension would be to add …
Looking through old issues/PRs (#96, #97), it seems a real blocker to extending lineax further to block operators, the Woodbury matrix identity, and chained inversion is the concept of an `InverseLinearOperator` that allows composition (both with `@` and `+`/`-`). I sketched up an implementation of this (see the first commit on this branch) but quickly realised it was essentially identical to a `FunctionLinearOperator` with custom tag rules and the ability to cache the state through `lx.linearise`. This felt like overengineering to me and would give us one more operator to maintain (e.g. if we add more single-dispatch functions or colouring rules). As such I instead decided to introduce a helper function that offers the same functionality. The key thing we lose is the ability to cache an `InverseLinearOperator` with `lx.linearise`; instead we have to decide which one we want when calling `lx.invert` with the `cache` keyword.

One idea I have in mind for using this is to introduce a `collapse: bool` keyword argument which, when `False` and provided with a `ComposedLinearOperator` and a direct solver, will return a composed chain of inverse operators for each child of the `ComposedLinearOperator` (in reverse order, of course). This is usually more efficient (at least once we have a more efficient QR solver, when `ormqr` is supported in JAX) unless inner dimensions are larger than outer ones, as it avoids the cost of pre-multiplying. We could have `collapse=None` to allow auto-detection of the large-inner-dimension case. `cache=False` would usually be preferred here to avoid excess memory requirements (whereas for the Woodbury matrix identity `cache=True` would probably be preferred).

As this is something that is not used elsewhere now, but could be used widely across the codebase in the future, I appreciate that getting the design just right is important, and I am therefore open to making significant changes to this if you have other ideas.
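To illustrate the proposed (hypothetical, not current lineax API) `collapse=False` behaviour with plain NumPy: since (AB)⁻¹ = B⁻¹A⁻¹, applying the inverses of the factors of a composed operator in reverse order gives the same result as inverting the collapsed product, while skipping the matmul that materialises the product:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4 * np.eye(4)  # shifted to keep factors well-conditioned
B = rng.normal(size=(4, 4)) + 4 * np.eye(4)
b = rng.normal(size=4)

# "Collapsed": materialise A @ B (an n x n matmul), then solve once.
collapsed = np.linalg.solve(A @ B, b)

# "Chained": (A @ B)^{-1} = B^{-1} A^{-1}, so solve against each factor
# in reverse order; the product is never formed.
chained = np.linalg.solve(B, np.linalg.solve(A, b))

assert np.allclose(collapsed, chained)
```

This is also where the inner-dimension caveat shows up: if the composed chain has a wide inner dimension, the per-factor solves act on larger intermediate vectors than a solve against the collapsed operator would, which is what `collapse=None` auto-detection would guard against.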