Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/tvboptim/experimental/network_dynamics/core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
(ODE/DDE/SDE/SDDE) through composition rather than inheritance.
"""

import math
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Union

Expand Down Expand Up @@ -555,8 +556,8 @@ def _get_initial_history(self, dt: float) -> Optional[jnp.ndarray]:
where n_steps = ceil(max_delay / dt)
"""
n_steps = max(
1, int(jnp.ceil(self.max_delay / dt))
) # at least 1 step (case: speed = inf)
1, math.ceil(float(self.max_delay) / dt)
) # at least 1 step (case: speed = inf); float() keeps it static under jit
return jnp.broadcast_to(
self.initial_state[None, :, :],
(n_steps, self.initial_state.shape[0], self.initial_state.shape[1]),
Expand Down
44 changes: 39 additions & 5 deletions src/tvboptim/experimental/network_dynamics/coupling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,13 +757,16 @@ class DelayedCoupling(PrePostCoupling):
optimization with small dt / large history. Memory: O(T).
- "preallocated": Pre-allocates full simulation buffer. Best forward-pass
performance but gradients degrade. Memory: O(T + simulation_steps).
interpolate_delays : bool, default False
If True ("roll" strategy only), read delayed states by linear interpolation between the two bracketing history steps instead of snapping to the nearest step. Makes the coupling differentiable w.r.t. the continuous delay (e.g. for optimising conduction speed). The delay gradient is only nonzero when the delays (and hence the interpolation fraction) are recomputed under tracing, i.e. when ``prepare`` runs inside the differentiated region.
**kwargs
Passed to parent class (incoming_states, local_states, params)
"""

def __init__(
self,
buffer_strategy: BufferStrategy = "roll",
interpolate_delays: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -774,7 +777,13 @@ def __init__(
f"Must be one of: 'roll', 'circular', 'preallocated'"
)

if interpolate_delays and buffer_strategy != "roll":
raise ValueError(
"interpolate_delays=True is only supported with buffer_strategy='roll'"
)

self.buffer_strategy = buffer_strategy
self.interpolate_delays = interpolate_delays

def prepare(self, network, dt: float, t0: float, t1: float) -> Tuple[Bunch, Bunch]:
"""Standard preparation for delayed coupling.
Expand All @@ -798,7 +807,7 @@ def prepare(self, network, dt: float, t0: float, t1: float) -> Tuple[Bunch, Bunc

# Convert delay times to discrete timesteps
delays_dense = _ensure_dense(graph.delays)
max_delay_steps = int(jnp.rint(graph.max_delay / dt))
max_delay_steps = int(round(float(graph.max_delay) / dt))
delay_steps = jnp.rint(delays_dense / dt).astype(jnp.int32)
base_history_length = max_delay_steps + 1

Expand Down Expand Up @@ -835,6 +844,21 @@ def prepare(self, network, dt: float, t0: float, t1: float) -> Tuple[Bunch, Bunc
bcoo_shape=bcoo_shape,
use_sparse_incoming=use_sparse_incoming,
)

if self.interpolate_delays:
# Bracket each delay by its two neighbouring history steps and store
# the fraction so compute() can linearly blend them.
delay_f = delays_dense / dt
delay_lo = jnp.floor(delay_f)
delay_frac = delay_f - delay_lo # in [0, 1)
idx_lo = (max_delay_steps - delay_lo).astype(jnp.int32)
idx_hi = idx_lo - 1
idx_lo = jnp.clip(idx_lo, 0, max_delay_steps)
idx_hi = jnp.clip(idx_hi, 0, max_delay_steps)
coupling_data.delay_indices = idx_lo
coupling_data.delay_indices_hi = idx_hi
coupling_data.delay_frac = delay_frac

coupling_state = Bunch(history=history_init)

elif self.buffer_strategy == "circular":
Expand Down Expand Up @@ -938,10 +962,20 @@ def compute(
# Extract delayed states using computed indices
# Result: delayed_states[i, j, k] = history[read_indices[j,k], i, k]
# i.e., incoming state i from source node k, delayed by τ_jk, going to target j
delayed_states = jnp.transpose(
coupling_state.history[read_indices, :, coupling_data.state_indices],
(2, 0, 1), # Reorder to [n_incoming, n_nodes_target, n_nodes_source]
)
def _gather(idx):
return jnp.transpose(
coupling_state.history[idx, :, coupling_data.state_indices],
(2, 0, 1), # Reorder to [n_incoming, n_nodes_target, n_nodes_source]
)

if self.buffer_strategy == "roll" and self.interpolate_delays:
# Interpolated delays: linearly blend the two bracketing history steps.
frac = coupling_data.delay_frac
delayed_states = (1.0 - frac) * _gather(read_indices) + frac * _gather(
coupling_data.delay_indices_hi
)
else:
delayed_states = _gather(read_indices)

# Extract local states
local_states = state[coupling_data.local_indices]
Expand Down
45 changes: 43 additions & 2 deletions src/tvboptim/experimental/network_dynamics/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ class DenseDelayGraph(DenseGraph):
delays: Delay matrix [n_nodes, n_nodes] in same units as integration time
region_labels: Optional sequence of region labels (list, tuple, or array). If None, defaults to ['Region_0', 'Region_1', ...]
symmetric: Whether to treat as symmetric (None = auto-detect)
max_delay: Optional maximum delay used to size the history buffer. Pass it explicitly to allow ``delays`` to be a JAX tracer (e.g. tract_length / speed); if None, derived from max(delays).
"""

def __init__(
Expand All @@ -410,6 +411,7 @@ def __init__(
delays: jnp.ndarray,
region_labels: Optional[Sequence[str]] = None,
symmetric: Optional[bool] = None,
max_delay: Optional[float] = None,
):
# Process delays first (needed for verify method)
self._delays = jnp.asarray(delays)
Expand All @@ -421,8 +423,13 @@ def __init__(
f"Delay matrix shape {self._delays.shape} must match weight matrix shape {weights_array.shape}"
)

# Compute and store max_delay to avoid accessing delays during pytree transformations
self._max_delay = float(jnp.max(self._delays))
# max_delay sizes the (static) history buffer. Passing it explicitly lets
# ``delays`` be a JAX tracer; otherwise it is derived from the delays.
if max_delay is not None:
self._max_delay = float(max_delay)
else:
self._max_delay = float(jnp.max(self._delays))
self._check_delays_within_buffer(self._delays)

# Initialize parent Graph (pass region_labels)
super().__init__(weights, region_labels=region_labels, symmetric=symmetric)
Expand All @@ -441,6 +448,40 @@ def max_delay(self) -> float:
"""Maximum delay in the network."""
return self._max_delay

def _check_delays_within_buffer(self, delays) -> None:
"""Raise if *concrete* delays exceed ``max_delay`` — the fixed-size history
buffer (``ceil(max_delay/dt)`` steps) would silently truncate them. Skipped
for traced delays (under jit/grad the check cannot run; size ``max_delay``
for the slowest conduction speed you will explore, e.g. max(length) / v_min)."""
if isinstance(delays, jax.core.Tracer):
return
dmax = float(jnp.max(delays))
if dmax > self._max_delay:
raise ValueError(
f"Largest delay ({dmax:g}) exceeds max_delay ({self._max_delay:g}); the "
f"fixed history buffer would truncate it. Set max_delay >= max(delays) "
f"(e.g. max(tract_length) / v_min when optimising conduction speed)."
)

def with_delays(self, delays: jnp.ndarray) -> "DenseDelayGraph":
"""Return a copy with ``delays`` replaced, reusing the static structure
(``max_delay``, labels, symmetry) without re-running ``__init__``/``verify``.

This is the jit/grad-safe way to vary the delays — e.g.
``graph.with_delays(lengths / speed)`` inside an optimised loss — so the
history-buffer length stays fixed by ``max_delay`` while the delays may
be JAX tracers. Build the graph once (concrete) to set ``max_delay``,
then call this to update the delays.
"""
delays = jnp.asarray(delays)
if delays.shape != self._delays.shape:
raise ValueError(
f"with_delays expects delays of shape {self._delays.shape}, got {delays.shape}"
)
self._check_delays_within_buffer(delays)
children, aux_data = self.tree_flatten()
return type(self).tree_unflatten(aux_data, (children[0], delays))

def verify(self, verbose: bool = True) -> bool:
"""Verify delay graph structure.

Expand Down
Loading
Loading