jvp optimisations for pseudoinverse solvers #217

Open
jpbrodrick89 wants to merge 1 commit into patrick-kidger:main from jpbrodrick89:jpb/pseudo-jvp

Conversation

@jpbrodrick89
Contributor

Preamble
First of all, sincere apologies for adding another PR to the backlog. This is definitely not one that would fit in an extension package, due to its invasive change to the linear_solve JVP rule, but it is also not one I'd mind terribly if you deprioritise reviewing, as I don't think it's a critical bottleneck in my work. My hope is that submitting it now rather than waiting for the backlog to clear simply means it has more time to swim around in your head for a smoother review when we eventually get there (no pressure! 🙂). Furthermore, this could affect new solvers such as pivoted QR.

Intention of PR
I've been working on variable projection, as well as thinking about improving JAX's least-squares JVP rule and comparing it to Lineax for inspiration. I noticed we were missing some low-hanging fruit when it comes to leveraging standard optimisations for the projection and Gram operators used in the JVP rule.

Dependent columns

The JVP rule has a term $A^\dagger A y$ which is the projection onto the row space. This has the following optimisations:

  • Wide QR: with $A^\top = QR$, $A^\dagger A = \mathrm{conj}(Q)\, R^{-\top} R^\top Q^\top = \mathrm{conj}(Q)\, Q^\top$ by unitarity, avoiding a matmul and a triangular solve.
  • SVD: $A^\dagger A = (U \Sigma V^\top)^\dagger (U \Sigma V^\top) = V \Sigma^{-1} U^\top U \Sigma V^\top = V V^\top$ by unitarity of $U$, saving two O(mn) matmuls.
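To make the row-space projection identities concrete, here is a small numerical check (an illustrative NumPy sketch, not Lineax code; the shapes and the choice of factoring $A^\top$ for the wide QR case are my assumptions) that the projection built from the orthogonal factors matches $A^\dagger A$:

```python
import numpy as np

# Hypothetical check of the row-space projection identities above, for a wide,
# full-row-rank A. The projection A^+ A can be formed from the orthogonal
# factors alone, skipping the triangular solve / extra matmuls.
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))  # wide: more columns than rows

# Reference: projection via the pseudoinverse.
proj_ref = np.linalg.pinv(A) @ A

# Wide QR route: factor A^T = Q R, then A^+ A = conj(Q) Q^T.
Q, R = np.linalg.qr(A.T)  # reduced QR: Q is (5, 3) with orthonormal columns
proj_qr = Q.conj() @ Q.T

# SVD route: A = U S V^T, then A^+ A = V V^T (V = right singular vectors).
U, S, Vt = np.linalg.svd(A, full_matrices=False)
proj_svd = Vt.T @ Vt

assert np.allclose(proj_ref, proj_qr)
assert np.allclose(proj_ref, proj_svd)
```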

Dependent rows

The JVP rule has a term $A^\dagger A^{H\dagger} \mathrm{d}A^H(b-Ax)$, where the first two matrices can be written as the (pseudo-)inverse of the Gram matrix $A^\dagger A^{H\dagger}=(A^H A)^\dagger$. This has the following optimisations:

  • Tall QR: $(A^H A)^\dagger = (R^H Q^H Q R)^{-1} = R^{-1}R^{-H}$ saving two O(mn) matmuls.
  • SVD: $(V \Sigma U^H U \Sigma V^H)^\dagger=V \Sigma^{-2} V^H$ saving two O(mn) matmuls.
  • Tall Normal: $A^H A$ *is* the `inner_operator`, so we use a single call to the `inner_solver` and avoid an application of $A$.
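The Gram-inverse identities for a tall, full-column-rank $A$ can likewise be checked numerically (again an illustrative NumPy sketch under those rank assumptions, not Lineax code):

```python
import numpy as np

# Hypothetical check that (A^H A)^{-1} can be applied via the factors alone,
# for a tall, full-column-rank A, avoiding two extra products with A.
rng = np.random.default_rng(0)
A = rng.standard_normal((6, 3))  # tall: more rows than columns

# Reference: invert the Gram matrix directly.
gram_inv_ref = np.linalg.inv(A.T @ A)

# Tall QR route: A = Q R gives (A^H A)^{-1} = R^{-1} R^{-H}.
Q, R = np.linalg.qr(A)
R_inv = np.linalg.inv(R)
gram_inv_qr = R_inv @ R_inv.T

# SVD route: A = U S V^H gives (A^H A)^{-1} = V S^{-2} V^H.
U, S, Vt = np.linalg.svd(A, full_matrices=False)
gram_inv_svd = Vt.T @ np.diag(S**-2) @ Vt

assert np.allclose(gram_inv_ref, gram_inv_qr)
assert np.allclose(gram_inv_ref, gram_inv_svd)
```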

Caveat

Savings are not always quite as good as they sound: before the outer solve is applied, the "vecs" are summed with other vectors that still have to be computed, so in practice the overall savings are about half what is advertised. The saving from the inner matrix, however, is still very real.

Design

I have introduced two `singledispatch` functions in `_solve.py`: `_gram_inverse_mv` and `_row_space_projection`. By default these return `NotImplemented` (they do NOT raise `NotImplementedError`), which allows falling back to the current path; solvers that register an implementation can then leverage these optimisations in the JVP rule.
