Skip to content

Add differentiable (interpolated) delays and explicit max_delay#7

Open
leon-k-martin wants to merge 6 commits into
mainfrom
differentiable-delays
Open

Add differentiable (interpolated) delays and explicit max_delay#7
leon-k-martin wants to merge 6 commits into
mainfrom
differentiable-delays

Conversation

@leon-k-martin

@leon-k-martin leon-k-martin commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds opt-in differentiable transmission delays, so conduction speed can be optimised by gradient descent.

  • DelayedCoupling(interpolate_delays=True) (roll strategy): reads each delayed state by linear interpolation between its two bracketing history steps instead of snapping to the nearest step. This makes the coupling differentiable in the continuous delay — hence in conduction speed (delay = tract_length / speed). Default False is byte-identical to the previous nearest-step behaviour.
  • DenseDelayGraph(max_delay=...): sizes the history buffer from an explicit static bound, decoupled from the delay values, so delays may be JAX tracers while the buffer shape stays fixed. Raises if concrete delays exceed the bound (which would silently truncate the buffer).
  • DenseDelayGraph.with_delays(delays): jit/grad-safe delay update (pytree replace, bypasses __init__/verify) — the way to vary delays = lengths / speed inside a traced loss.
  • jit-safety: history-buffer sizing in prepare/get_history now derives from the static max_delay via Python arithmetic, so the full forward + gradient path is jit-compatible.

Backward compatibility

Additive and off by default — interpolate_delays=False and max_delay=None reproduce existing behaviour exactly; no call sites change. The non-interpolated forward path is byte-identical (round/math.ceil equal the previous jnp.rint/jnp.ceil on concrete inputs).

Validation

tests/test_network_dynamics/test_interpolated_delays.py covers: forward with interpolation on/off across sizes 2/8/32; the interpolate_delays requires-roll guard; max_delay default vs explicit and the over-bound guard; with_delays; gradient vs central finite difference (incl. under jit); zero gradient for the nearest-step path; jitted conduction-speed recovery; and a heterogeneous, multi-parameter delay equation. A figures + runnable-demo comment follows below.

Notes

The delay gradient is nonzero only when the delays (and hence the interpolation fraction) are recomputed under tracing — i.e. when prepare runs inside the differentiated region.

DelayedCoupling gains an opt-in `interpolate_delays` flag (roll strategy
only): delayed states are read by linear interpolation between the two
bracketing history steps instead of snapping to the nearest step, making
the coupling differentiable w.r.t. the continuous delay (e.g. for
optimising conduction speed). The default (False) preserves the exact
nearest-step behaviour, so existing results are unchanged.

DenseDelayGraph gains an optional `max_delay` argument to size the history
buffer explicitly. This allows `delays` to be a JAX tracer (e.g.
delays = tract_length / speed) while the buffer length stays static; when
omitted, max_delay is derived from max(delays) as before.
@leon-k-martin leon-k-martin requested a review from mapi1 June 25, 2026 12:39
The history-buffer length was sized with jnp.rint/ceil, which fails under
jit (int()/round() of a traced array). Size it from the static max_delay
with Python round/math.ceil instead, so the buffer length stays static
while the per-edge delays may be JAX tracers — enabling jit-compiled
gradient optimisation of conduction speed (delays = lengths / speed).
Behaviour is unchanged for concrete inputs (round == rint, math.ceil ==
jnp.ceil).

Add DenseDelayGraph.with_delays(): update the delays via the pytree
(bypassing __init__/verify) so delays can be varied inside a jit/grad'd
loop while the buffer stays fixed by max_delay. Shape-checked.

Add tests/test_network_dynamics/test_interpolated_delays.py covering
forward (interp on/off, sizes), the constructor guard, max_delay,
with_delays, gradient-vs-finite-difference (incl. under jit), zero
gradient for the nearest-step path, jitted conduction-speed recovery,
and heterogeneous / multi-parameter delay equations.
@leon-k-martin

leon-k-martin commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator Author

Differentiable interpolated delays

interpolate_delays=True reads each delayed state by linear interpolation between its two bracketing history steps, making the coupling differentiable in the continuous delay — hence in conduction speed (delay = length / speed). DenseDelayGraph(max_delay=...) fixes the history-buffer length independently of the delay values, so delays may be JAX tracers; with_delays updates them jit-safely. In every figure the target is a simulation at a known speed (3.0), recovered by gradient descent.

Loss vs conduction speed

Loss (MSE to target) swept over speed. Nearest-step (red) snaps each delay to the closest buffer step, so the curve is a staircase — flat across each step, i.e. gradient 0 almost everywhere. Interpolation (blue) is continuous, with its minimum at the true speed (dashed).

A_loss_landscape

Gradient w.r.t. speed

jax.grad of the interpolated loss (blue) lies on central finite differences (orange dots; median relative error 1.4e-6) and crosses zero at the true speed. The nearest-step gradient (red) is identically 0 — no descent direction.

C_grad_vs_fd

Conduction-speed recovery

Adam (lr 0.03) on the interpolated loss. Speed (right) moves 4.5 → 2.999 (true 3.0) with a small undershoot; loss (left, log) drops ~5 orders of magnitude.

B_optimization

Source

import jax, jax.numpy as jnp, numpy as np, optax
jax.config.update("jax_enable_x64", True)
from tvboptim.experimental.network_dynamics import Network, solve
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling
from tvboptim.experimental.network_dynamics.dynamics.tvb import Generic2dOscillator
from tvboptim.experimental.network_dynamics.solvers import Heun

rng = np.random.default_rng(7); N = 8
W = jnp.asarray(np.where(np.eye(N), 0.0, rng.uniform(0, 0.1, (N, N))))
L = rng.uniform(10, 80, (N, N)); L = 0.5 * (L + L.T); np.fill_diagonal(L, 0.0); L = jnp.asarray(L)

def osc():
    return Generic2dOscillator(a=-1.5, b=-15.0, d=0.015, tau=4.0, INITIAL_STATE=(0.1, 0.1))

g0 = DenseDelayGraph(W, L / 3.0, max_delay=float(jnp.max(L)) / 2.0)  # static buffer, speed >= 2

def sim(speed):
    coup = DelayedLinearCoupling(incoming_states="V", G=0.4, interpolate_delays=True)
    return solve(Network(osc(), {"delayed": coup}, g0.with_delays(L / speed)), Heun(), t1=60.0, dt=0.5).ys

target = sim(3.0)
loss = jax.jit(lambda s: jnp.mean((sim(s) - target) ** 2))
grad = jax.jit(jax.grad(loss))

v, opt = jnp.asarray(4.5), optax.adam(0.03); st = opt.init(v)
for _ in range(150):
    upd, st = opt.update(grad(v), st); v = optax.apply_updates(v, upd)
print(f"recovered speed = {float(v):.4f}")  # 2.9989

Stress: forward output and jax.grad are finite and nonzero across N ∈ {2, 8, 32, 64}, with additive noise, and at dt ∈ {0.25, 0.5, 1.0}; gradient vs central finite difference agrees to median rel. err 1.4e-6; nearest-step gradient is 0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants