Support "diagonal" primitives with no/slow JVP batch rule #164

jpbrodrick89 wants to merge 48 commits into patrick-kidger:main
Conversation
merge main into fork
…ntire Jacobian matrix
I think I prefer this implementation (latest commit) with
I finally realised why. Would it be helpful to re-write the PR message with updated benchmarks and prose in light of this?
That reasoning was completely wrong again: the derivative is not evaluated at the tangent vectors, just multiplied by them. I think the reduced FLOPs due to matrix multiplication are only more noticeable on the example above. However, it is easy to fool:

```python
def myfunc(x):
    halfway = len(x) // 2
    return jnp.concatenate([jnp.sin(x[:halfway]), jnp.cos(x[halfway:])])
```

Here we see an O(n) speedup exceeding 1E4 for array sizes of 4E4, which is huge. In general I see no significant adverse impact of this and some very pronounced positive impacts in realistic use cases.
patrick-kidger
left a comment
Nice, I really like seeing a colouring approach like this!
I will note however that `lx.diagonal` is documented as "Extracts the diagonal from a linear operator, and returns a vector", which is meant to include extracting the diagonal from non-diagonal operators. I think the implementations you have here should check with `is_diagonal` to determine whether to dispatch to the new or the old implementation.
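For illustration, a self-contained sketch of the suggested dispatch (the class and helper names here are hypothetical, not lineax's internals): use the diagonal tag to choose between the fast identity and full extraction.

```python
import jax.numpy as jnp

class Op:
    """Toy stand-in for a linear operator carrying a diagonal tag."""

    def __init__(self, matrix, tagged_diagonal=False):
        self.matrix = matrix
        self.tagged_diagonal = tagged_diagonal

    def mv(self, v):
        return self.matrix @ v

    def as_matrix(self):
        return self.matrix

def is_diagonal(op):
    return op.tagged_diagonal

def diagonal(op):
    if is_diagonal(op):
        # Fast path: for a truly diagonal A, A @ ones == diag(A).
        return op.mv(jnp.ones(op.as_matrix().shape[1]))
    # Old behaviour: extract the diagonal from the materialised matrix.
    return jnp.diag(op.as_matrix())
```

Only operators whose tag promises diagonality take the cheap matrix-vector path; everything else keeps the documented "extract the diagonal" semantics.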
Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think
However, they're actually a standalone public API themselves, which could be used independently of the solvers. :)
Sorry for abandoning this for so long as the day job took over. I returned as it was found to provide orders of magnitude impact on a problem I was working on (root finding over multiple interpolations), even for small array sizes (200–2000). I have addressed the main point of retaining extraction of the diagonal when the diagonal tag is missing. However, your other two comments seem at odds with each other: either we ensure input/output structures match (which is actually the case, see above), enabling us to extract diagonal leaves for
Sorry for the long delay getting back to you; some personal life things took over for a while. So, now to actually answer your question: good point. I imagine we could probably do the 'diagonal leaf' approach when the structures match, and go for the more expensive approach when they don't?
…full_rank (patrick-kidger#158)

The two functions allow_dependent_{rows,columns} together did the job of answering if the solver accepts full-rank matrices for the purposes of the jvp. Allowing them to be implemented separately created some issues:

1) Invalid states were representable. E.g. what does it mean that dependent columns are allowed for square matrices if dependent rows are not? What does it mean that dependent rows are not allowed for matrices with more rows than columns?
2) As the functions accept operator as input, a custom solver could in principle decide its answer based on operator's dynamic value rather than only jax-compilation-static information regarding it, as in all the lineax-defined solvers. This would prevent jax compilation and jit.

Both issues are addressed by asking the solver to report only if it assumes the input is numerically full rank. If this assumption is violated, its behavior is allowed to be undefined: it may error, produce NaN values, or produce invalid values.
No worries, hope you're managing alright. Sorry to miss you at DiffSys, but looking forward to catching up with Johanna! Just to check I understand correctly: shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have matching input and output structures, and continue to raise ValueErrors when this is violated? Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal-leaf approach for PyTreeLinearOperators and we're done?
* Add sparse materialisation helper and efficient diagonal paths

This PR introduces the _try_sparse_materialise helper and optimizes diagonal operator handling throughout lineax.

Key changes:
- Add _try_sparse_materialise() that converts diagonal-tagged operators to DiagonalLinearOperator, preserving pytree structure via unravel
- Add efficient diagonal() for JLO/FLO using a single JVP/VJP with a ones basis
- Add efficient diagonal() for Composed: diag(A @ B) = diag(A) * diag(B)
- Simplify mv() for MLO, PTLO, Add, Composed to use _try_sparse_materialise
- Apply early sparse materialisation in materialise() registrations

Aux handling:
- Fix bug: linearise/materialise now preserve aux on AuxLinearOperator
- Preserve aux from the first operator in Composed (output comes from op1)
- Inner aux in Add children silently stripped (unclear semantics; may warrant guards in future)

---------

Co-authored-by: jpbrodrick89 <jpbrodrick89@users.noreply.github.com>
Note this is largely ready to go and includes #195 (but not #196); happy to get those in first and merge this in after for a cleaner diff. Main change I've made is the
Also happy to be targeting a dev branch if you prefer, as these are not just "fixes"; one just doesn't exist right now.
Great, now completely ready to go! 😃 (This is completely independent of #198.) Happy to hear any final thoughts, and appreciate that helper functions might benefit from moving around etc.
patrick-kidger
left a comment
Okay indeed final thoughts – I think this all looks pretty good to me! Minor comments below only.
Whilst I think this is fine and I am happy to merge this as-is, I am wondering if the many special-cases are indicative of the need for some other structure here. (E.g. maybe we should arrange for materialise to always start with _try_sparse_materialise, and then fall back to per-operator overloads?) In general I'm not sure what that is though.
I don't think I'd like to see too much more complexity here beyond what we have now, though; e.g. I'm a bit antsy about what the analogous tridiagonal implementations will look like layered on top of this.
lineax/_operator.py
```python
if is_diagonal(operator):
    return diagonal(operator.operator1) * diagonal(operator.operator2)
```
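For context, a quick numerical check of the identity this diff relies on, for genuinely diagonal factors:

```python
import jax.numpy as jnp

# For diagonal A and B: diag(A @ B) == diag(A) * diag(B).
A = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
B = jnp.diag(jnp.array([4.0, 5.0, 6.0]))
lhs = jnp.diag(A @ B)
rhs = jnp.diag(A) * jnp.diag(B)  # both equal [4., 10., 18.]
```

Note the identity only holds when both factors really are diagonal, which is the point under discussion below.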
I don't think a diagonal composed operator is necessarily formed from two diagonal operators.
See lines 2100-2109 of _operator.py:

```python
# These properties ARE preserved under composition
for check in (
    is_diagonal,
    is_lower_triangular,
    is_upper_triangular,
):
    @check.register(ComposedLinearOperator)
    def _(operator, check=check):
        return check(operator.operator1) and check(operator.operator2)
```
Oh wait, you mean the converse might not be true. Yes, good catch: theoretically you're right. But right now in lineax that's not true, as we don't look for tags on the composed linear operator. So if it's not made of two diagonal operators it won't REGISTER as being diagonal, so it's fine for now unless you want to change that behaviour.
I've hardcoded the explicit `is_diagonal(operator.operator1) and is_diagonal(operator.operator2)` check here too, for forward compatibility in case we change it in the future.
The only two ways I can think of to do it is either (1) a full on breaking change or (2) make pyright very unhappy when users define materialise registrations for custom operators.
I really don't think the tridiagonal case will be that bad with the current structure. I'll try to neaten it up tomorrow. I'd really appreciate it if we can get that into the release, as that is the killer use case that would make a lot of difference to me.
Tests still pass locally. I'm wondering if it's due to RNG (i.e. the random vec is sometimes in the range of the singular matrix and sometimes not)?



**BREAKING CHANGE:** Diagonals are no longer "extracted" from operators but rely on the promise of the `diagonal` tag being accurate and fulfilled. If this promise is broken, solutions may differ from user expectations.

**Preface**
At present, both `JacobianLinearOperator` and `FunctionLinearOperator` require full materialisation even if provided with a `diagonal` tag. This seems self-evidently expensive (in practice it certainly can be, but more often is not; see below) and requires the underlying function (which could potentially be a custom primitive) to have a batching rule. As it is currently the case that tags are considered to be a "promise" and are unchecked with no guarantee of behaviour, there are some shortcuts we can take.

**Changes made**
The proposal here is to use the observation that the diagonal of a matrix can be obtained by pre/post-multiplying it by a vector of ones, and thereby re-write the single-dispatch `diagonal` method for `JacobianLinearOperator` and `FunctionLinearOperator` so that the `as_matrix()` method is not required. For `JacobianLinearOperator`, either `jax.jvp` or `jax.vjp` will be called depending on the `jac` keyword (forward-mode should always be more efficient, but meeting the user's expectation will avoid issues if forward-mode is not supported, such as when a `custom_vjp` is used). For `FunctionLinearOperator`, we can just use `self.mv`. However, if the matrix is not actually diagonal this identity will not hold, and results may be unexpected due to contributions from off-diagonals.
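A minimal sketch of the forward-mode identity in pure JAX (the function name is illustrative, not lineax's implementation): for a function whose Jacobian is promised to be diagonal, a single JVP against a vector of ones recovers the diagonal.

```python
import jax
import jax.numpy as jnp

def diagonal_via_jvp(fn, x):
    # J @ ones == diag(J) whenever J is truly diagonal.
    _, tangent = jax.jvp(fn, (x,), (jnp.ones_like(x),))
    return tangent

x = jnp.array([0.0, 1.0, 2.0])
# Elementwise sin has a diagonal Jacobian, so this equals cos(x).
diag = diagonal_via_jvp(jnp.sin, x)
```

One JVP instead of materialising an n×n Jacobian is where the claimed savings come from.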
I considered using `operator.transpose().mv` instead of writing out the `vjp`, but if the matrix is tagged as `symmetric` then this would end up calling `jacrev` instead of `vjp`.

**Why is this helpful?**
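The reverse-mode analogue, again as an illustrative sketch rather than the PR's exact code: pre-multiplying by ones via a single VJP.

```python
import jax
import jax.numpy as jnp

def diagonal_via_vjp(fn, x):
    # ones^T @ J == diag(J) whenever J is truly diagonal.
    out, pullback = jax.vjp(fn, x)
    (diag,) = pullback(jnp.ones_like(out))
    return diag

x = jnp.array([0.0, 1.0, 2.0])
diag = diagonal_via_vjp(jnp.sin, x)  # again cos(x) for elementwise sin
```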
When using `lineax` directly one can of course just define a `DiagonalOperator` instead of a more general `JacobianLinearOperator`, but this is not always possible. For example, when using `optimistix`, the operator is instantiated within the optimisation routine and the only way to inform the optimiser about the underlying structure of the matrix is through `tags`. Therefore, if the function being optimised is a primitive (e.g. an FFI) with a JVP rule that does not support batching, a user is stuck. If a slow batching rule, such as `vmap_method="sequential"`, is used, the current approach is also painfully slow for large matrix sizes.

**Performance impact**
I had initially hoped this to have a minor positive impact on performance across the board, but as ever I have massively underestimated the power of XLA. In practice, whether this PR improves performance (e.g. for a `linear_solve` or an `optimistix.root_find`) of a pure JAX function appears to fluctuate with array size. By playing around with different `XLA_FLAGS` and other environment variables, my best guess is that this is mostly due to threading; a `vmap` applied to a `jnp.eye` is threaded much more aggressively, meaning that the apparent time complexity appears to be of lower order than a more direct approach. However, when I tried to eliminate threading, this PR still seems to have an 8–10% negative impact on performance for array sizes > 100 on an `optimistix.root_find`.

**Pure `jax` comparison: using `jvp` when attempting to enforce single-threadedness is about 14µs faster**
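The benchmarked functions are collapsed in the original post; a hypothetical reconstruction of the two approaches being compared (names `from_eye`/`direct` taken from the plots, bodies are my assumption) might look like:

```python
import jax
import jax.numpy as jnp

def from_eye(f, x):
    # Current approach: materialise the Jacobian by pushing an identity
    # basis through vmap'd JVPs, then take its diagonal.
    jac = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(jnp.eye(x.size))
    return jnp.diag(jac)

def direct(f, x):
    # Proposed approach: a single JVP against a vector of ones.
    return jax.jvp(f, (x,), (jnp.ones_like(x),))[1]
```

For an elementwise function both return the same diagonal, but `from_eye` does n JVPs (one per basis vector) where `direct` does one.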
It seems self-evident that the second function should be more efficient; however, with the new `thunk` runtime on my Mac, `from_eye` runs faster than `direct` (referred to as `wrapped` in the diagram below; you can ignore `unwrapped` and `vmap`, which have similar performance) for array sizes > ~1.5E4:

*[benchmark plot]*

Disabling the `thunk` runtime (with `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`, which is reported to run faster in some circumstances) decreases the gap between the two, by slightly slowing down the `eye` implementation and accelerating the `direct` approach:

*[benchmark plot]*

Going further and following all the suggestions in github.com/jax-ml/jax/discussions/22739 to limit to one thread/core, we can see the `direct` approach is now consistently about 14µs faster:

*[benchmark plot]*

**`linear_solve` significantly faster (often >2x) for array sizes <2E4 using thunk runtime, but runs about 6–10% slower for large array sizes when disabling and attempting to enforce a single thread**
**Code tested**

Using the standard thunk runtime and `EQX_ON_ERROR=nan`, we see significant speedup for array sizes < 1E4:

*[benchmark plot]*

Enforcing a single thread, the performance between the old and the new approaches is very similar, but tracks at about 6–10% slower for larger array sizes:

*[benchmark plot]*

(Note that `DiagonalOperator` is actually slower somehow.)

**Similar behaviour is observed with `optimistix.root_find` (but with more modest gains, and some hits for larger array sizes)**
I compared performance for a multi-root find of the `sin` function (with `EQX_ON_ERROR=nan`). Default settings (`jax` 0.6.1, `main` vs this branch of `lineax`), with and without the standard `thunk` runtime:

*[benchmark plot]*

In both runtimes this PR improves/maintains performance by up to a factor of 2 for arrays of size up to 1E4, at which point it becomes slightly slower than the current version (by ~8%).
However, limiting to one thread as best as I can, most of the noise is eliminated and the two have very similar performance (the change tracking about 6% slower), except for an array size of 20, where the proposed change is faster:

*[benchmark plot]*
**Much more substantial performance improvement (8x or higher) is observed for primitives that only support `sequential` batching rules**
This is a very contrived example, but based on very real use cases we have over at tesseract-core and tesseract-jax. I have defined a new primitive version of `sin` with a `jvp` rule that batches sequentially and is therefore slow, and doesn't benefit from compilation/threading in the same way:

**Code for primitives**
I then ran the same tests as before but with `sin_p` instead of `jnp.sin`, and we can see the time complexity of the current version is almost quadratic for array sizes greater than 100 (as one would naively expect for a dense Jacobian), meaning that speedups range from a factor of 2 (array size of 20) to a factor of 8 (array size of 5000) and higher:

*[benchmark plot]*

Running `benchmarks/solver_speed.py` shows a negligible improvement in the single `Diagonal` solve but a 50% faster batch solve; this could of course be down to noise, as the solve is only timed once. (This uses `lx.Diagonal`, so it is not relevant and probably just a fluke.)

**Testing done**
- Amended `test_diagonal` such that operators are actually initialised with diagonal matrices
- Root finds (`Newton` and `Bisection`) of a scalar function with no batching rule, taking gradients through the root solve (not possible previously); this tests both `JacobianLinearOperator` and `FunctionLinearOperator` in action
- Tested `diagonal` from `JacobianLinearOperator` with `jac="bwd"`

Happy to perform any further requested testing you see fit/necessary. I appreciate I haven't managed to test reverse-mode especially extensively.
**Next steps**
In a future PR, I would like to do something similar for other structures (e.g. tridiagonal); this should address the large O(n) discrepancy observed in #149 (but not the O(0) discrepancy). I believe this will be a much more consistent and meaningful gain than observed here. This PR should be a lot easier to grok, making it possible to reason about the concept and discuss framework/design choices (although maybe not the performance impact :) ) before building out further.