In writing #210, I discovered a bug in JAX's differentiation rule for pivoted QR. However, I should never have been able to observe this bug: lineax code should never need this JVP rule, since lineax defines its own `linear_solve_p` primitive with a custom JVP. Here is a minimal example showing the issue:
```python
import jax
import lineax as lx
from lineax._solve import AbstractLinearSolver

key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
m = jax.random.normal(key1, (3, 3))
mt = jax.random.normal(key2, (3, 3))
v = jax.random.normal(key3, (3,))


class DisallowGradWrapper(AbstractLinearSolver):
    """Wraps another solver and fails loudly if its `init` is ever differentiated."""

    solver: AbstractLinearSolver

    def init(self, operator, options):
        @jax.custom_jvp
        def f(operator):
            return self.solver.init(operator, options)

        @f.defjvp
        def _(*args):
            assert False  # This assertion gets hit

        return f(operator)

    def compute(self, state, vector, options):
        return self.solver.compute(state, vector, options)

    def transpose(self, state, options):
        return self.solver.transpose(state, options)

    def conj(self, state, options):
        return self.solver.conj(state, options)

    def assume_full_rank(self):
        return self.solver.assume_full_rank()


f = lambda m: lx.linear_solve(
    lx.MatrixLinearOperator(m), v, solver=DisallowGradWrapper(lx.QR())
).value
res = jax.jvp(f, (m,), (mt,))
```
`DisallowGradWrapper` asserts that `solver.init()` is never differentiated through, yet calling `jax.jvp` on the solve hits the assertion.
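For reference, the expectation above rests on how `jax.custom_jvp` works: when JAX differentiates a function that has a custom JVP, it calls the user-supplied rule on plain primal values, so code the rule runs on those primals is not itself differentiated. The minimal sketch below illustrates this in plain JAX, without lineax. The names `inner`, `outer`, `outer_stop_grad` and `calls` are hypothetical: `inner` stands in for `solver.init()` and `outer` for a primitive with its own JVP rule such as `linear_solve_p`. For contrast, `outer_stop_grad` shows that applying `jax.lax.stop_gradient` to a result does not prevent the JVP rules used to compute it from running under `jax.jvp`; whether lineax's `init` call is reached through a path like this is an assumption, not something checked here.

```python
import jax
import jax.numpy as jnp

calls = []  # records which custom JVP rules actually run


@jax.custom_jvp
def inner(x):
    # Hypothetical stand-in for `solver.init()`: its JVP rule should never run.
    return jnp.sin(x)


@inner.defjvp
def inner_jvp(primals, tangents):
    calls.append("inner")
    (x,), (t,) = primals, tangents
    return jnp.sin(x), jnp.cos(x) * t


@jax.custom_jvp
def outer(x):
    # Hypothetical stand-in for a primitive with its own JVP rule.
    return inner(x)


@outer.defjvp
def outer_jvp(primals, tangents):
    calls.append("outer")
    (x,), (t,) = primals, tangents
    # `x` is a plain primal value here, so calling `inner` does not invoke `inner_jvp`.
    y = inner(x)
    return y, jnp.cos(x) * t


def outer_stop_grad(x):
    # Blocking gradients after the fact: `inner_jvp` still runs under jax.jvp,
    # and only its tangent output is then replaced by zeros.
    return jax.lax.stop_gradient(inner(x))


x0 = jnp.array(1.0)
t0 = jnp.ones_like(x0)

calls.clear()
jax.jvp(outer, (x0,), (t0,))
outer_calls = list(calls)

calls.clear()
jax.jvp(outer_stop_grad, (x0,), (t0,))
stop_grad_calls = list(calls)

print(outer_calls)  # ['outer']
print(stop_grad_calls)  # ['inner']
```

With the custom JVP in place only the outer rule runs, whereas `stop_gradient` still lets the inner rule execute.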