Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lineax/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,11 @@ def linear_solve(
stats={},
)
if state == sentinel:
state = solver.init(operator, options)
dynamic_operator, static_operator = eqx.partition(operator, eqx.is_array)
stopped_operator = eqx.combine(
lax.stop_gradient(dynamic_operator), static_operator
)
state = solver.init(stopped_operator, options)

dynamic_state, static_state = eqx.partition(state, eqx.is_array)
dynamic_state = lax.stop_gradient(dynamic_state)
Expand Down
7 changes: 3 additions & 4 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,9 @@ def jvp_jvp_impl(
if use_state:

def linear_solve1(operator, vector):
state = solver.init(operator, options={})
state_dynamic, state_static = eqx.partition(state, eqx.is_inexact_array)
state_dynamic = lax.stop_gradient(state_dynamic)
state = eqx.combine(state_dynamic, state_static)
op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array)
stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static)
state = solver.init(stopped_operator, options={})

sol = lx.linear_solve(operator, vector, state=state, solver=solver)
return sol.value
Expand Down
52 changes: 52 additions & 0 deletions tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,58 @@ def test_iterative_solver_max_steps_only(solver):
lx.linear_solve(poisson_operator, rhs, solver)


def test_solver_init_not_differentiated(getkey):
"""stop_gradient should be applied before solver.init, not after.

Also checks that dynamic arrays in options don't cause issues.
"""

class DisallowGradWrapper(lx._solve.AbstractLinearSolver):
solver: lx._solve.AbstractLinearSolver

def init(self, operator, options):
@jax.custom_jvp
def f(operator, dummy):
del dummy
return self.solver.init(operator, options)

@f.defjvp
def _(*args):
raise NotImplementedError("solver.init should not be differentiated")

return f(operator, options.get("dummy"))

def compute(self, state, vector, options):
return self.solver.compute(state, vector, options)

def transpose(self, state, options):
return self.solver.transpose(state, options)

def conj(self, state, options):
return self.solver.conj(state, options)

def assume_full_rank(self):
return self.solver.assume_full_rank()

m = jax.random.normal(getkey(), (3, 3))
mt = jax.random.normal(getkey(), (3, 3))
v = jax.random.normal(getkey(), (3,))
dummy = jnp.array(1.0)

def f(m):
op = lx.MatrixLinearOperator(m)
return lx.linear_solve(
op, v, solver=DisallowGradWrapper(lx.QR()), options={"dummy": dummy}
).value

# Differentiating through operator only, but options has a dynamic array.
# solver.init should not be differentiated through.
jax.jvp(f, (m,), (mt,))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe for completeness we could test the backward pass too:

_, f_vjp = jax.vjp(f, m)
f_vjp(vt)

where earlier line 233 maybe

vt = jax.random.normal(getkey(), (3,))


_, f_vjp = jax.vjp(f, m)
f_vjp(v)


def test_nonfinite_input():
operator = lx.DiagonalLinearOperator((1.0, 1.0))
vector = (1.0, jnp.inf)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_vmap_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ def test_vmap_jvp(
if use_state:

def linear_solve1(operator, vector):
state = solver.init(operator, options={})
state_dynamic, state_static = eqx.partition(state, eqx.is_inexact_array)
state_dynamic = lax.stop_gradient(state_dynamic)
state = eqx.combine(state_dynamic, state_static)
op_dynamic, op_static = eqx.partition(operator, eqx.is_inexact_array)
stopped_operator = eqx.combine(lax.stop_gradient(op_dynamic), op_static)
state = solver.init(stopped_operator, options={})

return lx.linear_solve(operator, vector, state=state, solver=solver)

Expand Down
Loading