Differentiation of linear_solve causes differentiation through factorization (solver state) #211

@adconner

While writing #210, I discovered a bug in JAX's differentiation rule for pivoted QR. However, I should never have been able to observe this bug: lineax should never need that JVP rule, because it defines its own linear_solve_p primitive with a custom JVP. Here is a minimal example showing the issue:

import jax
import lineax as lx
from lineax._solve import AbstractLinearSolver

key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
m = jax.random.normal(key1, (3, 3))   # matrix defining the linear system
mt = jax.random.normal(key2, (3, 3))  # tangent direction for m
v = jax.random.normal(key3, (3,))     # right-hand side

class DisallowGradWrapper(AbstractLinearSolver):
    # Delegates everything to the wrapped solver, but fails loudly if
    # init() (the factorization step) is ever differentiated.
    solver: AbstractLinearSolver

    def init(self, operator, options):
        @jax.custom_jvp
        def f(operator):
            return self.solver.init(operator, options)

        @f.defjvp
        def _(*args):
            assert False  # This assertion gets hit

        return f(operator)

    def compute(self, state, vector, options):
        return self.solver.compute(state, vector, options)

    def transpose(self, state, options):
        return self.solver.transpose(state, options)

    def conj(self, state, options):
        return self.solver.conj(state, options)

    def assume_full_rank(self):
        return self.solver.assume_full_rank()

# Solve m x = v with the wrapped QR solver, then take a JVP with respect to m.
f = lambda m: lx.linear_solve(lx.MatrixLinearOperator(m), v, solver=DisallowGradWrapper(lx.QR())).value
res = jax.jvp(f, (m,), (mt,))

The wrapper asserts that solver.init() is never differentiated through, but the assertion is hit, so the JVP is differentiating through the factorization (the solver state).
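For comparison, here is a minimal sketch in plain JAX of the behaviour I would expect. It is not lineax's implementation; `factorize` and `solve` are hypothetical names made up for this illustration. It assumes a square, well-conditioned matrix, so that differentiating A x = b gives dx = A^{-1}(db - dA x). The tangent can therefore be computed by reusing the factorization as a plain value, and the guarded factorization is never differentiated:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


@jax.custom_jvp
def factorize(a):
    # Hypothetical stand-in for solver.init(): an unpivoted QR factorization.
    q, r = jnp.linalg.qr(a)
    return q, r


@factorize.defjvp
def _factorize_jvp(primals, tangents):
    # Same guard as in the issue: fail if the factorization is differentiated.
    raise AssertionError("factorization was differentiated")


def _apply_inverse(state, rhs):
    # Hypothetical stand-in for solver.compute(): solve R x = Q^T rhs.
    q, r = state
    return solve_triangular(r, q.T @ rhs)


@jax.custom_jvp
def solve(a, b):
    return _apply_inverse(factorize(a), b)


@solve.defjvp
def _solve_jvp(primals, tangents):
    a, b = primals
    da, db = tangents
    # Here `a` is a primal value, so factorize() is only evaluated.
    state = factorize(a)
    x = _apply_inverse(state, b)
    # Tangent of the solution: dx = A^{-1} (db - dA x), reusing the same state.
    dx = _apply_inverse(state, db - da @ x)
    return x, dx


key_a, key_da, key_b = jax.random.split(jax.random.key(0), 3)
# Shift by a multiple of the identity to keep the example well conditioned.
a = jax.random.normal(key_a, (3, 3)) + 5.0 * jnp.eye(3)
da = jax.random.normal(key_da, (3, 3))
b = jax.random.normal(key_b, (3,))
db = jnp.zeros_like(b)

# The guard in _factorize_jvp is not triggered by this first-order JVP.
x, dx = jax.jvp(solve, (a, b), (da, db))

# Reference values from JAX's built-in dense solve.
x_ref, dx_ref = jax.jvp(jnp.linalg.solve, (a, b), (da, db))
```

In this sketch the JVP rule receives primal values rather than JVP tracers, so the guard only fires if something differentiates `factorize` directly. The minimal example above shows that this is not what currently happens for solver.init() inside lx.linear_solve.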