Add differentiable (interpolated) delays and explicit max_delay#7
Open
leon-k-martin wants to merge 6 commits into
Open
Add differentiable (interpolated) delays and explicit max_delay#7leon-k-martin wants to merge 6 commits into
leon-k-martin wants to merge 6 commits into
Conversation
DelayedCoupling gains an opt-in `interpolate_delays` flag (roll strategy only): delayed states are read by linear interpolation between the two bracketing history steps instead of snapping to the nearest step, making the coupling differentiable w.r.t. the continuous delay (e.g. for optimising conduction speed). The default (False) preserves the exact nearest-step behaviour, so existing results are unchanged. DenseDelayGraph gains an optional `max_delay` argument to size the history buffer explicitly. This allows `delays` to be a JAX tracer (e.g. delays = tract_length / speed) while the buffer length stays static; when omitted, max_delay is derived from max(delays) as before.
The history-buffer length was sized with jnp.rint/ceil, which fails under jit (int()/round() of a traced array). Size it from the static max_delay with Python round/math.ceil instead, so the buffer length stays static while the per-edge delays may be JAX tracers — enabling jit-compiled gradient optimisation of conduction speed (delays = lengths / speed). Behaviour is unchanged for concrete inputs (round == rint, math.ceil == jnp.ceil). Add DenseDelayGraph.with_delays(): update the delays via the pytree (bypassing __init__/verify) so delays can be varied inside a jit/grad'd loop while the buffer stays fixed by max_delay. Shape-checked. Add tests/test_network_dynamics/test_interpolated_delays.py covering forward (interp on/off, sizes), the constructor guard, max_delay, with_delays, gradient-vs-finite-difference (incl. under jit), zero gradient for the nearest-step path, jitted conduction-speed recovery, and heterogeneous / multi-parameter delay equations.
…entiable delays tests
Collaborator
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.



Summary
Adds opt-in differentiable transmission delays, so conduction speed can be optimised by gradient descent.
DelayedCoupling(interpolate_delays=True)(rollstrategy): reads each delayed state by linear interpolation between its two bracketing history steps instead of snapping to the nearest step. This makes the coupling differentiable in the continuous delay — hence in conduction speed (delay = tract_length / speed). DefaultFalseis byte-identical to the previous nearest-step behaviour.DenseDelayGraph(max_delay=...): sizes the history buffer from an explicit static bound, decoupled from the delay values, sodelaysmay be JAX tracers while the buffer shape stays fixed. Raises if concrete delays exceed the bound (which would silently truncate the buffer).DenseDelayGraph.with_delays(delays): jit/grad-safe delay update (pytree replace, bypasses__init__/verify) — the way to varydelays = lengths / speedinside a traced loss.prepare/get_historynow derives from the staticmax_delayvia Python arithmetic, so the full forward + gradient path is jit-compatible.Backward compatibility
Additive and off by default —
interpolate_delays=Falseandmax_delay=Nonereproduce existing behaviour exactly; no call sites change. The non-interpolated forward path is byte-identical (round/math.ceilequal the previousjnp.rint/jnp.ceilon concrete inputs).Validation
tests/test_network_dynamics/test_interpolated_delays.pycovers: forward with interpolation on/off across sizes 2/8/32; theinterpolate_delaysrequires-rollguard;max_delaydefault vs explicit and the over-bound guard;with_delays; gradient vs central finite difference (incl. underjit); zero gradient for the nearest-step path; jitted conduction-speed recovery; and a heterogeneous, multi-parameter delay equation. A figures + runnable-demo comment follows below.Notes
The delay gradient is nonzero only when the delays (and hence the interpolation fraction) are recomputed under tracing — i.e. when
prepareruns inside the differentiated region.