
Add invert helper function to wrap linear_solve in FunctionLinearOperator#206

Open
jpbrodrick89 wants to merge 12 commits into patrick-kidger:main from jpbrodrick89:jpb/invert

Conversation

@jpbrodrick89
Contributor

Looking through old issues/PRs (#96, #97), it seems a real blocker to extending lineax further to block operators, the Woodbury matrix identity, and chained inversion is the lack of an InverseLinearOperator that allows composition (both with @ and +/-). I sketched an implementation of this (see the first commit on this branch) but quickly realised it was essentially identical to a FunctionLinearOperator with custom tag rules and the ability to cache the state through lx.linearise. This felt like overengineering to me and would give us one more operator to maintain (e.g. if we add more single-dispatch functions or colouring rules). As such, I instead decided to introduce a helper function that offers the same functionality. The key thing we lose is the ability to cache an InverseLinearOperator with lx.linearise; instead we have to decide which behaviour we want when calling lx.invert, via the cache keyword.
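To make the cache tradeoff concrete, here is a plain-numpy stand-in for the idea (the invert name and cache keyword mirror the proposal, but this is illustrative only, not lineax's API): cache=True pays the factorisation-equivalent cost once up front, while cache=False defers all work to each application.

```python
import numpy as np

def invert(A, cache=True):
    """Illustrative stand-in for the proposed helper: returns a function
    applying A^{-1} to a vector. cache=True does the expensive work once
    up front; cache=False re-solves from scratch on every call."""
    if cache:
        A_inv = np.linalg.inv(A)  # stands in for a cached solver state
        return lambda b: A_inv @ b
    return lambda b: np.linalg.solve(A, b)

A = np.array([[4.0, 1.0], [2.0, 3.0]])
b = np.array([1.0, 2.0])
mv_cached = invert(A, cache=True)
mv_fresh = invert(A, cache=False)
```

Both variants agree with a direct solve; the difference is purely where the factorisation cost lands.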

One idea I have in mind for using this is to introduce a collapse: bool keyword argument which, when False and provided with a ComposedLinearOperator and a direct solver, will return a composed chain of inverse operators, one for each child of the ComposedLinearOperator (in reverse order, of course). This is usually more efficient (at least once we have a more efficient QR solver, when ormqr is supported in JAX), unless the inner dimensions are larger than the outer ones, as it avoids the cost of pre-multiplying. We could have collapse=None to allow auto-detection of the large-inner-dimension case. cache=False would usually be preferred here to avoid excess memory requirements (whereas for the Woodbury matrix identity cache=True would probably be preferred).
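The reverse-order chain rests on the identity (AB)⁻¹ = B⁻¹A⁻¹; a quick numpy sanity check of that algebra (the matrices are arbitrary well-conditioned examples, not anything lineax-specific):

```python
import numpy as np

rng = np.random.default_rng(0)
# diagonally-shifted random matrices, so both are safely invertible
A = rng.standard_normal((4, 4)) + 4 * np.eye(4)
B = rng.standard_normal((4, 4)) + 4 * np.eye(4)

collapsed = np.linalg.inv(A @ B)                 # invert the collapsed product
chained = np.linalg.inv(B) @ np.linalg.inv(A)    # chain of inverses, reversed
assert np.allclose(collapsed, chained)
```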

As this is something that is not used elsewhere yet, but could be used widely across the codebase in future, I appreciate that getting the design just right is important, and I am therefore open to making significant changes if you have other ideas.

Owner

@patrick-kidger patrick-kidger left a comment


I think in principle this looks pretty good to me!

I think my main (only?) concern is what will happen when we compose this into larger systems — like #196, where it turns out that we're playing whack-a-mole building our own compiler over operators. That's not really a blocker but we might want to be up-front about how far we intend to pursue such optimizations — as my preference would be 'not much', mostly for maintenance time reasons.

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Feb 27, 2026

Thanks Patrick, I appreciate the concern. Always happy to have a call if that's easier for discussion. The main general "optimisation" in the spirit of #196 that I still have in mind is to use multi_dot in ComposedLinearOperator.mv, for shape-based optimisation only, once it supports batching properly (see JAX issue 35308). As mentioned in #196, I don't propose optimising this further based on operator type (or on the often-misleading XLA flop count) or on linear transposes; MAYBE based on structure (through tags), but I don't have any ideas for this right now. As such, an inverse operator would act like any other when composing, and its ordering would depend purely on shape (this is reasonable, as once state is cached most solvers have similar scaling to matmuls). When adding to another operator, lineax already takes the simple, usually-preferred approach of just applying each operator separately. I also don't think it's possible or sensible to attempt to automatically detect independent applications of an inverse and vmap under the hood (e.g. lx.invert(A) @ C + B @ lx.invert(A)). We either just rely on XLA or do this manually in the solver, as we would for the Woodbury case.
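For reference, numpy's multi_dot already does exactly this kind of shape-only optimisation for dense matrices — it picks the association order from the shapes alone, which is the behaviour being proposed here for ComposedLinearOperator.mv:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 2))
B = rng.standard_normal((2, 100))
v = rng.standard_normal((100, 1))

# Naive left-to-right evaluation, (A @ B) @ v, forms a 100x100 intermediate
# (~30k flops); multi_dot instead picks A @ (B @ v) (~400 flops) purely
# from the operand shapes, with no knowledge of the operators themselves.
out = np.linalg.multi_dot([A, B, v])
assert np.allclose(out, A @ (B @ v))
```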

For inverse operators specifically, I think the two examples of an inverse chain and the Woodbury matrix identity are good ones to reason about when thinking how bad this could get (happy to submit draft PRs if you'd rather see a working example). First, though, I think we should limit our reasoning to direct solvers only; I don't envisage many cases where lx.invert would be sensible for iterative solvers.

Let's start with Woodbury because it's a textbook case. We could either use @aidancrilly's approach from #97 of providing each matrix directly, or extract them from an AddLinearOperator (the key thing to note is that lx.invert doesn't FORCE us into either implementation, and if we decide the latter is too complex/hard to maintain we can opt for the more explicit approach). Either way we'd end up with M = A + U C V, with A and C square. We will assume A_solver, C_solver and S_solver are all direct square solvers, and that we cannot detect sparsity for S (so calling as_matrix is always mandatory), as I don't think it makes any sense to use an iterative solver here. The code would then look something like this, if cache defaults to True and assuming for simplicity that C has flat in and out structures (note I think we handle pytrees out of the box):

A_inv = lx.invert(A, solver=A_solver)
A_inv_U = jax.vmap(A_inv.mv, in_axes=1, out_axes=1)(lx.materialise(U).pytree)
A_inv_b = A_inv.mv(b)
S = lx.invert(C, solver=C_solver) + PyTreeLinearOperator(V.mv(A_inv_U), out_structure=...)
sol = (IdentityLinearOperator(...) - A_inv_U @ lx.invert(S) @ V).mv(A_inv_b)

Yes, I did oversimplify things and miss out on an optimisation by stacking U and b, but this seems pretty elegant to me and the opposite of hard to maintain (compared to a non-operator approach, at least). Just like in #196, I don't think it makes any sense to try to overengineer by detecting whether it's more efficient to calculate A_inv_U or V_A_inv.
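The algebra behind the sketch is the standard Woodbury identity, (A + UCV)⁻¹b = A⁻¹b − A⁻¹U(C⁻¹ + VA⁻¹U)⁻¹VA⁻¹b, which can be checked with plain numpy (dimensions and matrices chosen arbitrarily; this is only the maths, not the lineax operator plumbing):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 6, 2
# diagonally-shifted so A and C are safely invertible
A = rng.standard_normal((n, n)) + n * np.eye(n)
C = rng.standard_normal((k, k)) + k * np.eye(k)
U = rng.standard_normal((n, k))
V = rng.standard_normal((k, n))
b = rng.standard_normal(n)

A_inv_U = np.linalg.solve(A, U)          # n solves against the tall factor
A_inv_b = np.linalg.solve(A, b)
S = np.linalg.inv(C) + V @ A_inv_U       # small k x k capacitance matrix
sol = A_inv_b - A_inv_U @ np.linalg.solve(S, V @ A_inv_b)

# agrees with solving the full n x n system directly
assert np.allclose(sol, np.linalg.solve(A + U @ C @ V, b))
```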

Now the more controversial approach is the inverse chain, which would probably need to be advertised as experimental to begin with. Yes, I admit this does feel a bit like building a compiler over operators, but I think it could be really valuable. I'd just take a really simple approach to begin with, where each matrix in a composed chain is inverted unless the inner dimensions are bigger than the outer, with no special handling of add operators (we could allow Woodbury to handle those in the future, but I admit that would get convoluted and hard to maintain).

So lx.invert(A@(B+C)) -> lx.invert(B+C) @ lx.invert(A)

Also if you hate this but love the Woodbury we could just do that for now and I can play around with the chain stuff outwith lineax.

I have no deep thoughts/great insight on how to handle block matrices, and if it turns out that this is too hard to address in lineax, so be it; at least we're making a best effort in providing the foundation of lx.invert in the first place, and maybe others can play around and work it out. 🙂

Another thing to make sure you're comfortable with is that we don't EXPLICITLY call solver.transpose or solver.conj. I think transpose is fine, as it will call your custom transpose rule, which calls these under the hood. I'm not 100% sure whether there's a great efficiency loss when calling lx.conj(lx.invert(op)), but my gut feeling is that it doesn't matter.
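For what it's worth, the underlying algebra is safe either way, since transpose and conjugation both commute with inversion; a plain-numpy check (arbitrary well-conditioned complex matrix):

```python
import numpy as np

rng = np.random.default_rng(0)
# diagonally-shifted random complex matrix, safely invertible
A = (rng.standard_normal((3, 3))
     + 1j * rng.standard_normal((3, 3))
     + 3 * np.eye(3))

# transpose commutes with inversion: (A^T)^{-1} == (A^{-1})^T
assert np.allclose(np.linalg.inv(A).T, np.linalg.inv(A.T))
# conjugation commutes with inversion: conj(A)^{-1} == conj(A^{-1})
assert np.allclose(np.linalg.inv(A).conj(), np.linalg.inv(A.conj()))
```

So deriving the transposed/conjugated behaviour from the original solver, rather than calling solver.transpose/solver.conj directly, is at worst an efficiency question, not a correctness one.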

Owner

@patrick-kidger patrick-kidger left a comment


Okay! Then in this case I think I'm pretty happy with this. Two outstanding minor nits (one new one + the cache one) and then we can merge this. I'm currently aiming to merge this and #164 and then do a new release.

lineax/_solve.py Outdated
Comment on lines +372 to +379
!!! Warning

Passing `state` to `linear_solve` does not support autodiff out of
the box. If you need to differentiate through the solve, either
wrap the state with `jax.lax.stop_gradient`, or use
[`lineax.invert`][] (with `cache=True`) which handles this
automatically.

Owner


I don't think 'state' is defined outside of the example code above. (I'm not sure I understand the warning here.)

Contributor Author


For example, if you comment out line 371 of tests/helpers.py (dynamic_state = lax.stop_gradient(dynamic_state)), or the similar stop_gradient in this PR, you get a bunch of failures in test_jvp_jvp*.py and two failures in test_invert.py.

RuntimeError: Unexpected tangent. `lineax.linear_solve(..., state=...)` cannot be autodifferentiated.

My point is that with lx.invert you don't need to manually add stop_gradient if you want to re-use state. The warning is specifically located after the state re-use example for that reason. Do you think I should remove or reword it?

Contributor Author


Also see #127 which is slightly related (but different).

Owner

@patrick-kidger patrick-kidger Mar 9, 2026


Ah, I see what you're getting at! This makes sense to me.

I'm actually not sure I really believe in our requirement that state be nondifferentiable. This was meant to enforce that the user would explicitly place a lax.stop_gradient in their own code, so that they would realise that there is no gradient there. But in retrospect I think that might be too niche/annoying.

So as an alternative, I've just opened #213 – WDYT? (I'm also conscious that it will slightly impact #212.)

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Mar 12, 2026

Not sure if you're still keen on having the helper in lineax main given the new stop-gradient handling, but if you are, one possible extension would be to add a ("has_inverse", op) tag (or similar) to the returned inverse operator, to tell lineax we already have the inverse of the inverse (the original operator) computed. Probably overly complex/overengineering for now, but I thought I'd mention it as it crossed my mind. 🙂
