diff --git a/src/tvboptim/experimental/network_dynamics/core/network.py b/src/tvboptim/experimental/network_dynamics/core/network.py index 9026025..bb71eb7 100644 --- a/src/tvboptim/experimental/network_dynamics/core/network.py +++ b/src/tvboptim/experimental/network_dynamics/core/network.py @@ -4,6 +4,7 @@ (ODE/DDE/SDE/SDDE) through composition rather than inheritance. """ +import math import warnings from typing import TYPE_CHECKING, Dict, List, Optional, Union @@ -555,8 +556,8 @@ def _get_initial_history(self, dt: float) -> Optional[jnp.ndarray]: where n_steps = ceil(max_delay / dt) """ n_steps = max( - 1, int(jnp.ceil(self.max_delay / dt)) - ) # at least 1 step (case: speed = inf) + 1, math.ceil(float(self.max_delay) / dt) + ) # at least 1 step (case: speed = inf); float() keeps it static under jit return jnp.broadcast_to( self.initial_state[None, :, :], (n_steps, self.initial_state.shape[0], self.initial_state.shape[1]), diff --git a/src/tvboptim/experimental/network_dynamics/coupling/base.py b/src/tvboptim/experimental/network_dynamics/coupling/base.py index ba198fd..5c5941c 100644 --- a/src/tvboptim/experimental/network_dynamics/coupling/base.py +++ b/src/tvboptim/experimental/network_dynamics/coupling/base.py @@ -757,6 +757,8 @@ class DelayedCoupling(PrePostCoupling): optimization with small dt / large history. Memory: O(T). - "preallocated": Pre-allocates full simulation buffer. Best forward-pass performance but gradients degrade. Memory: O(T + simulation_steps). + interpolate_delays : bool, default False + If True ("roll" strategy only), read delayed states by linear interpolation between the two bracketing history steps instead of snapping to the nearest step. Makes the coupling differentiable w.r.t. the continuous delay (e.g. for optimising conduction speed). The delay gradient is only nonzero when the delays (and hence the interpolation fraction) are recomputed under tracing, i.e. when ``prepare`` runs inside the differentiated region. **kwargs Passed to parent class (incoming_states, local_states, params) """ @@ -764,6 +766,7 @@ class DelayedCoupling(PrePostCoupling): def __init__( self, buffer_strategy: BufferStrategy = "roll", + interpolate_delays: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -774,7 +777,13 @@ def __init__( f"Must be one of: 'roll', 'circular', 'preallocated'" ) + if interpolate_delays and buffer_strategy != "roll": + raise ValueError( + "interpolate_delays=True is only supported with buffer_strategy='roll'" + ) + self.buffer_strategy = buffer_strategy + self.interpolate_delays = interpolate_delays def prepare(self, network, dt: float, t0: float, t1: float) -> Tuple[Bunch, Bunch]: """Standard preparation for delayed coupling. @@ -798,7 +807,7 @@ def prepare(self, network, dt: float, t0: float, t1: float) -> Tuple[Bunch, Bunc # Convert delay times to discrete timesteps delays_dense = _ensure_dense(graph.delays) - max_delay_steps = int(jnp.rint(graph.max_delay / dt)) + max_delay_steps = int(round(float(graph.max_delay) / dt)) delay_steps = jnp.rint(delays_dense / dt).astype(jnp.int32) base_history_length = max_delay_steps + 1 @@ -835,6 +844,21 @@ def prepare(self, network, dt: float, t0: float, t1: float) -> Tuple[Bunch, Bunc bcoo_shape=bcoo_shape, use_sparse_incoming=use_sparse_incoming, ) + + if self.interpolate_delays: + # Bracket each delay by its two neighbouring history steps and store + # the fraction so compute() can linearly blend them. + delay_f = delays_dense / dt + delay_lo = jnp.floor(delay_f) + delay_frac = delay_f - delay_lo # in [0, 1) + idx_lo = (max_delay_steps - delay_lo).astype(jnp.int32) + idx_hi = idx_lo - 1 + idx_lo = jnp.clip(idx_lo, 0, max_delay_steps) + idx_hi = jnp.clip(idx_hi, 0, max_delay_steps) + coupling_data.delay_indices = idx_lo + coupling_data.delay_indices_hi = idx_hi + coupling_data.delay_frac = delay_frac + coupling_state = Bunch(history=history_init) elif self.buffer_strategy == "circular": @@ -938,10 +962,20 @@ def compute( # Extract delayed states using computed indices # Result: delayed_states[i, j, k] = history[read_indices[j,k], i, k] # i.e., incoming state i from source node k, delayed by τ_jk, going to target j - delayed_states = jnp.transpose( - coupling_state.history[read_indices, :, coupling_data.state_indices], - (2, 0, 1), # Reorder to [n_incoming, n_nodes_target, n_nodes_source] - ) + def _gather(idx): + return jnp.transpose( + coupling_state.history[idx, :, coupling_data.state_indices], + (2, 0, 1), # Reorder to [n_incoming, n_nodes_target, n_nodes_source] + ) + + if self.buffer_strategy == "roll" and self.interpolate_delays: + # Interpolated delays: linearly blend the two bracketing history steps. + frac = coupling_data.delay_frac + delayed_states = (1.0 - frac) * _gather(read_indices) + frac * _gather( + coupling_data.delay_indices_hi + ) + else: + delayed_states = _gather(read_indices) # Extract local states local_states = state[coupling_data.local_indices] diff --git a/src/tvboptim/experimental/network_dynamics/graph/base.py b/src/tvboptim/experimental/network_dynamics/graph/base.py index e7cb161..a9b7761 100644 --- a/src/tvboptim/experimental/network_dynamics/graph/base.py +++ b/src/tvboptim/experimental/network_dynamics/graph/base.py @@ -402,6 +402,7 @@ class DenseDelayGraph(DenseGraph): delays: Delay matrix [n_nodes, n_nodes] in same units as integration time region_labels: Optional sequence of region labels (list, tuple, or array). If None, defaults to ['Region_0', 'Region_1', ...] symmetric: Whether to treat as symmetric (None = auto-detect) + max_delay: Optional maximum delay used to size the history buffer. Pass it explicitly to allow ``delays`` to be a JAX tracer (e.g. tract_length / speed); if None, derived from max(delays). """ def __init__( @@ -410,6 +411,7 @@ def __init__( delays: jnp.ndarray, region_labels: Optional[Sequence[str]] = None, symmetric: Optional[bool] = None, + max_delay: Optional[float] = None, ): # Process delays first (needed for verify method) self._delays = jnp.asarray(delays) @@ -421,8 +423,13 @@ def __init__( f"Delay matrix shape {self._delays.shape} must match weight matrix shape {weights_array.shape}" ) - # Compute and store max_delay to avoid accessing delays during pytree transformations - self._max_delay = float(jnp.max(self._delays)) + # max_delay sizes the (static) history buffer. Passing it explicitly lets + # ``delays`` be a JAX tracer; otherwise it is derived from the delays. + if max_delay is not None: + self._max_delay = float(max_delay) + else: + self._max_delay = float(jnp.max(self._delays)) + self._check_delays_within_buffer(self._delays) # Initialize parent Graph (pass region_labels) super().__init__(weights, region_labels=region_labels, symmetric=symmetric) @@ -441,6 +448,40 @@ def max_delay(self) -> float: """Maximum delay in the network.""" return self._max_delay + def _check_delays_within_buffer(self, delays) -> None: + """Raise if *concrete* delays exceed ``max_delay`` — the fixed-size history + buffer (``ceil(max_delay/dt)`` steps) would silently truncate them. Skipped + for traced delays (under jit/grad the check cannot run; size ``max_delay`` + for the slowest conduction speed you will explore, e.g. max(length) / v_min).""" + if isinstance(delays, jax.core.Tracer): + return + dmax = float(jnp.max(delays)) + if dmax > self._max_delay: + raise ValueError( + f"Largest delay ({dmax:g}) exceeds max_delay ({self._max_delay:g}); the " + f"fixed history buffer would truncate it. Set max_delay >= max(delays) " + f"(e.g. max(tract_length) / v_min when optimising conduction speed)." + ) + + def with_delays(self, delays: jnp.ndarray) -> "DenseDelayGraph": + """Return a copy with ``delays`` replaced, reusing the static structure + (``max_delay``, labels, symmetry) without re-running ``__init__``/``verify``. + + This is the jit/grad-safe way to vary the delays — e.g. + ``graph.with_delays(lengths / speed)`` inside an optimised loss — so the + history-buffer length stays fixed by ``max_delay`` while the delays may + be JAX tracers. Build the graph once (concrete) to set ``max_delay``, + then call this to update the delays. + """ + delays = jnp.asarray(delays) + if delays.shape != self._delays.shape: + raise ValueError( + f"with_delays expects delays of shape {self._delays.shape}, got {delays.shape}" + ) + self._check_delays_within_buffer(delays) + children, aux_data = self.tree_flatten() + return type(self).tree_unflatten(aux_data, (children[0], delays)) + def verify(self, verbose: bool = True) -> bool: """Verify delay graph structure. diff --git a/tests/test_network_dynamics/test_interpolated_delays.py b/tests/test_network_dynamics/test_interpolated_delays.py new file mode 100644 index 0000000..1d3e9a6 --- /dev/null +++ b/tests/test_network_dynamics/test_interpolated_delays.py @@ -0,0 +1,278 @@ +"""Tests for differentiable (interpolated) delays and DenseDelayGraph(max_delay). + +Covers the two features added in this change: + +* ``DelayedCoupling(interpolate_delays=True)`` — reads delayed states by linear + interpolation between the two bracketing history steps instead of snapping to + the nearest step, making the coupling differentiable w.r.t. the continuous + delay (and hence w.r.t. conduction speed, since ``delay = length / speed``). +* ``DenseDelayGraph(max_delay=...)`` — sizes the (static) history buffer + explicitly so the ``delays`` matrix may be a JAX tracer, enabling + gradient-based optimisation while the buffer length stays static. + +The whole path (graph construction + ``prepare`` + ``solve``) is exercised under +both bare ``jax.grad`` and ``jax.jit``. +""" + +import unittest + +import jax +import jax.numpy as jnp +import numpy as np + +jax.config.update("jax_enable_x64", True) + +import optax + +from tvboptim.experimental.network_dynamics import Network, solve +from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling +from tvboptim.experimental.network_dynamics.dynamics.tvb import Generic2dOscillator +from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph +from tvboptim.experimental.network_dynamics.solvers import Heun + +DT = 0.5 +T1 = 40.0 + + +def _osc(): + return Generic2dOscillator( + a=-1.5, b=-15.0, d=0.015, tau=4.0, INITIAL_STATE=(0.1, 0.1) + ) + + +def _conn(n, seed): + """Random weights + symmetric tract-length matrix of size n.""" + rng = np.random.default_rng(seed) + w = np.where(np.eye(n), 0.0, rng.uniform(0.0, 0.1, (n, n))) + length = rng.uniform(10.0, 80.0, (n, n)) + length = 0.5 * (length + length.T) + np.fill_diagonal(length, 0.0) + return jnp.asarray(w), jnp.asarray(length) + + +def _net_from_graph(graph, interpolate, G=0.3): + coup = DelayedLinearCoupling( + incoming_states="V", G=G, interpolate_delays=interpolate + ) + return Network(_osc(), {"delayed": coup}, graph) + + +def _run_graph(graph, interpolate, G=0.3): + return solve( + _net_from_graph(graph, interpolate, G), Heun(), t0=0.0, t1=T1, dt=DT + ).ys + + +def _run(weights, delays, interpolate, max_delay, G=0.3): + return _run_graph( + DenseDelayGraph(weights, delays, max_delay=max_delay), interpolate, G + ) + + +class TestForward(unittest.TestCase): + """Forward simulation with interpolation across network sizes.""" + + def test_finite_and_shaped(self): + for n in (2, 8, 32): + with self.subTest(n=n): + w, length = _conn(n, seed=n) + delays = length / 3.0 + md = float(jnp.max(delays)) + ys = _run(w, delays, interpolate=True, max_delay=md) + self.assertEqual(ys.shape[-1], n) + self.assertTrue(bool(jnp.all(jnp.isfinite(ys)))) + + def test_interpolation_is_bounded_correction(self): + for n in (2, 8, 32): + with self.subTest(n=n): + w, length = _conn(n, seed=100 + n) + delays = length / 3.0 + md = float(jnp.max(delays)) + snap = _run(w, delays, interpolate=False, max_delay=md) + interp = _run(w, delays, interpolate=True, max_delay=md) + self.assertTrue(bool(jnp.all(jnp.isfinite(interp)))) + diff = float(jnp.max(jnp.abs(interp - snap))) + scale = float(jnp.mean(jnp.abs(snap))) + 1e-9 + self.assertLess(diff, 5.0 * scale) # a bounded correction + + def test_interpolation_is_active_off_grid(self): + """Half-step-offset delays (frac==0.5) must differ from nearest-step.""" + w, length = _conn(8, seed=42) + snapped = jnp.round((length / 3.0) / DT) + delays = (snapped + 0.5) * DT # exactly half a step off the grid + md = float(jnp.max(delays)) + snap = _run(w, delays, interpolate=False, max_delay=md) + interp = _run(w, delays, interpolate=True, max_delay=md) + self.assertGreater(float(jnp.max(jnp.abs(interp - snap))), 0.0) + + def test_reduces_to_snapping_on_grid(self): + """When every delay is an exact multiple of dt, frac==0 -> identical to snap.""" + w, length = _conn(8, seed=7) + delays = jnp.round((length / 3.0) / DT) * DT # exact grid points + md = float(jnp.max(delays)) + snap = _run(w, delays, interpolate=False, max_delay=md) + interp = _run(w, delays, interpolate=True, max_delay=md) + np.testing.assert_allclose( + np.asarray(interp), np.asarray(snap), rtol=0, atol=1e-12 + ) + + +class TestConstructorGuards(unittest.TestCase): + def test_interpolate_requires_roll(self): + for strategy in ("circular", "preallocated"): + with self.subTest(strategy=strategy): + with self.assertRaises(ValueError): + DelayedLinearCoupling( + incoming_states="V", + buffer_strategy=strategy, + interpolate_delays=True, + ) + + def test_non_roll_without_interpolate_allowed(self): + for strategy in ("roll", "circular", "preallocated"): + with self.subTest(strategy=strategy): + DelayedLinearCoupling(incoming_states="V", buffer_strategy=strategy) + + def test_roll_with_interpolate_allowed(self): + DelayedLinearCoupling( + incoming_states="V", buffer_strategy="roll", interpolate_delays=True + ) + + +class TestMaxDelay(unittest.TestCase): + def test_default_derived_from_delays(self): + w, length = _conn(6, seed=3) + delays = length / 3.0 + graph = DenseDelayGraph(w, delays) + self.assertAlmostEqual(graph.max_delay, float(jnp.max(delays)), places=10) + + def test_explicit_overrides_and_sizes_buffer(self): + w, length = _conn(6, seed=3) + delays = length / 3.0 + explicit = float(jnp.max(delays)) * 2.0 + graph = DenseDelayGraph(w, delays, max_delay=explicit) + self.assertAlmostEqual(graph.max_delay, explicit, places=10) + # An explicit (larger-than-needed) max_delay still simulates correctly. + ys = _run(w, delays, interpolate=True, max_delay=explicit) + self.assertTrue(bool(jnp.all(jnp.isfinite(ys)))) + + def test_tracer_delays_require_max_delay(self): + w, length = _conn(6, seed=3) + # Without max_delay, a tracer ``delays`` cannot size the static buffer. + with self.assertRaises(Exception): + jax.grad(lambda v: jnp.sum(DenseDelayGraph(w, length / v).delays))( + jnp.asarray(3.0) + ) + # With max_delay set, it is fine and differentiable. + md = float(jnp.max(length)) / 2.0 + g = jax.grad( + lambda v: jnp.sum(DenseDelayGraph(w, length / v, max_delay=md).delays) + )(jnp.asarray(3.0)) + self.assertTrue(np.isfinite(float(g))) + self.assertNotEqual(float(g), 0.0) + + +class TestWithDelays(unittest.TestCase): + def test_replaces_delays_preserving_static_structure(self): + w, length = _conn(6, seed=9) + g = DenseDelayGraph(w, length / 3.0, max_delay=50.0) + g2 = g.with_delays(length / 5.0) + np.testing.assert_allclose(np.asarray(g2.delays), np.asarray(length / 5.0)) + self.assertEqual(g2.max_delay, 50.0) # static buffer bound preserved + self.assertEqual(g2.n_nodes, g.n_nodes) + # original graph is unchanged (immutable update) + np.testing.assert_allclose(np.asarray(g.delays), np.asarray(length / 3.0)) + + def test_with_delays_is_jit_safe(self): + w, length = _conn(6, seed=9) + g = DenseDelayGraph(w, length / 3.0, max_delay=50.0) + # rebuilding via with_delays inside jit must not trigger __init__/verify + delays = jax.jit(lambda v: g.with_delays(length / v).delays)(jnp.asarray(4.0)) + np.testing.assert_allclose(np.asarray(delays), np.asarray(length / 4.0)) + + +class TestDifferentiability(unittest.TestCase): + """Gradient of a loss w.r.t. conduction speed, validated against finite diff.""" + + def setUp(self): + self.w, self.length = _conn(6, seed=0) + self.max_delay = float(jnp.max(self.length)) / 2.0 # valid for speed >= 2 + self.true_speed = 3.0 + # Build the graph once (concrete) to fix the static buffer via max_delay, + # then vary the delays with with_delays() — never reconstructed under jit. + self.graph = DenseDelayGraph( + self.w, self.length / self.true_speed, max_delay=self.max_delay + ) + self.target = _run_graph(self.graph, True, G=0.4) + + def _loss(self, speed, interpolate): + g = self.graph.with_delays(self.length / speed) + return jnp.mean((_run_graph(g, interpolate, G=0.4) - self.target) ** 2) + + def test_grad_matches_finite_difference(self): + s0 = 4.0 + ad = float(jax.grad(lambda s: self._loss(s, True))(jnp.asarray(s0))) + fd = float((self._loss(s0 + 1e-3, True) - self._loss(s0 - 1e-3, True)) / 2e-3) + self.assertNotEqual(ad, 0.0) + self.assertLess(abs(ad - fd) / (abs(fd) + 1e-30), 1e-3) + + def test_snap_gives_zero_gradient(self): + ad = float(jax.grad(lambda s: self._loss(s, False))(jnp.asarray(4.0))) + self.assertLess(abs(ad), 1e-12) # nearest-step is piecewise constant in speed + + def test_jit_grad_matches_finite_difference(self): + s0 = 4.0 + ad = float(jax.jit(jax.grad(lambda s: self._loss(s, True)))(jnp.asarray(s0))) + fd = float((self._loss(s0 + 1e-3, True) - self._loss(s0 - 1e-3, True)) / 2e-3) + self.assertLess(abs(ad - fd) / (abs(fd) + 1e-30), 1e-3) + + def test_jitted_optimization_recovers_speed(self): + def loss(s): + return self._loss(s, True) + + grad = jax.jit(jax.grad(loss)) + v = jnp.asarray(4.0) + opt = optax.adam(0.1) + st = opt.init(v) + l0 = float(loss(v)) + for _ in range(60): + updates, st = opt.update(grad(v), st) + v = optax.apply_updates(v, updates) + self.assertLess(float(loss(v)), l0) + # genuinely recovered, not merely nudged toward the target + self.assertLess(abs(float(v) - self.true_speed), 0.2) + + +class TestHeterogeneousDelayEquation(unittest.TestCase): + """delays = offset + L / speed_per_source_node -> heterogeneous, multi-parameter.""" + + def setUp(self): + self.w, self.length = _conn(6, seed=11) + self.n = 6 + self.max_delay = float(jnp.max(self.length)) / 1.5 + self.graph = DenseDelayGraph( + self.w, self.length / 3.0, max_delay=self.max_delay + ) + + def _delays(self, theta): + return theta["offset"] + self.length / theta["speed"][None, :] + + def _loss(self, theta, target): + ys = _run_graph(self.graph.with_delays(self._delays(theta)), True, G=0.4) + return jnp.mean((ys - target) ** 2) + + def test_grad_flows_through_vector_speed_and_offset(self): + true = { + "speed": jnp.full(self.n, 3.0).at[: self.n // 2].set(2.0), + "offset": jnp.asarray(0.5), + } + target = _run(self.w, self._delays(true), True, self.max_delay, G=0.4) + th0 = {"speed": jnp.full(self.n, 3.0), "offset": jnp.asarray(0.0)} + g = jax.jit(jax.grad(lambda t: self._loss(t, target)))(th0) + self.assertTrue(bool(jnp.all(jnp.isfinite(g["speed"])))) + self.assertGreater(float(jnp.max(jnp.abs(g["speed"]))), 0.0) + self.assertTrue(np.isfinite(float(g["offset"]))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tvbo/__init__.py b/tests/tvbo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tvbo/test_differentiable_delays_experiment.py b/tests/tvbo/test_differentiable_delays_experiment.py new file mode 100644 index 0000000..777e59d --- /dev/null +++ b/tests/tvbo/test_differentiable_delays_experiment.py @@ -0,0 +1,151 @@ +"""End-to-end differentiable-delays experiment driven through tvbo. + +The whole simulation — model, delayed (interpolated) coupling, connectivity, +a conduction-speed exploration, and a conduction-speed optimization — is +declared inline as a single YAML string and executed with + + SimulationExperiment.from_string(exp_yaml).run("tvboptim") + +The delayed coupling sets ``interpolate_delays: true``; tvbo's tvboptim code +generator emits ``DelayedCoupling(interpolate_delays=True, ...)`` (the feature +added on the tvboptim side), so the conduction delays are differentiable and +the run is jit-compatible. + +Requires the optional ``tvbo`` package built with the ``interpolate_delays`` +coupling slot + codegen (skipped otherwise); ``tvbo`` is not a CI dependency. +The optimization test additionally needs tvbo to expose ``conduction_speed`` as +an optimizable parameter (it is currently excluded from +``SimulationExperiment.collect_state``); until then it is skipped. +""" + +import unittest + +import numpy as np +import pytest + +pytest.importorskip("tvbo") + +from tvbo import SimulationExperiment # noqa: E402 + +try: + from tvbo.datamodel.pydantic import Coupling as _Coupling + + _HAS_INTERPOLATE = "interpolate_delays" in getattr(_Coupling, "model_fields", {}) +except Exception: + _HAS_INTERPOLATE = False + +exp_yaml = """ +label: Differentiable delays — conduction-speed exploration + optimization +dynamics: + iri: tvbo:Generic2dOscillator +network: + number_of_nodes: 4 + coupling: + c_glob: + iri: tvbo:Linear + delayed: true + interpolate_delays: true + parameters: + G: {value: 0.5} + parameters: + conduction_speed: {value: 3.0, unit: mm_per_ms} + nodes: + - {id: 0, label: R0, dynamics: Generic2dOscillator} + - {id: 1, label: R1, dynamics: Generic2dOscillator} + - {id: 2, label: R2, dynamics: Generic2dOscillator} + - {id: 3, label: R3, dynamics: Generic2dOscillator} + edges: + - {source: 0, target: 1, directed: false, parameters: {weight: {value: 0.5}, length: {value: 30.0, unit: mm}}} + - {source: 1, target: 2, directed: false, parameters: {weight: {value: 0.4}, length: {value: 45.0, unit: mm}}} + - {source: 2, target: 3, directed: false, parameters: {weight: {value: 0.3}, length: {value: 60.0, unit: mm}}} + - {source: 0, target: 3, directed: false, parameters: {weight: {value: 0.2}, length: {value: 75.0, unit: mm}}} +integration: + method: Heun + step_size: 0.5 + duration: 50 +observations: + activity: + label: Mean V activity + source: [V] +explorations: + speed_sweep: + space: + - parameter: conduction_speed + domain: {lo: 2.0, hi: 5.0, n: 3} +optimizations: + speed_fit: + loss: + function: mse + arguments: + - {name: simulated, value: observations.activity.data} + - {name: target, value: 0.0} + stages: + - name: fit_speed + free_parameters: + - parameter: conduction_speed + algorithm: adam + learning_rate: 0.2 + max_iterations: 5 +""" + + +def _is_conduction_speed_unsupported(exc: Exception) -> bool: + """tvbo builds that don't expose conduction_speed fail with it missing from state.""" + return "conduction_speed" in str(exc) + + +@unittest.skipUnless( + _HAS_INTERPOLATE, "tvbo build lacks the interpolate_delays coupling slot" +) +class TestDifferentiableDelaysExperiment(unittest.TestCase): + def test_codegen_emits_interpolate_delays(self): + """The tvboptim code generator wires interpolate_delays into the coupling.""" + code = SimulationExperiment.from_string(exp_yaml).render_code("tvboptim") + self.assertIn("interpolate_delays=True", code) + + def test_run_tvboptim_forward(self): + """A forward run integrates the interpolated-delay network to a finite series.""" + result = SimulationExperiment.from_string(exp_yaml).run( + "tvboptim", mode="simulation" + ) + ys = np.asarray(getattr(result.integration, "data", result.integration)) + self.assertEqual(ys.shape[-1], 4) # 4 nodes + self.assertTrue(bool(np.all(np.isfinite(ys)))) + + def test_conduction_speed_exploration(self): + """The experiment sweeps conduction_speed over a 3-point grid (interpolated delays).""" + result = SimulationExperiment.from_string(exp_yaml).run( + "tvboptim", mode="exploration" + ) + axis = result.explorations.speed_sweep.axes[0] + self.assertEqual(axis.name, "conduction_speed") + self.assertEqual(len(np.asarray(axis.explored_values)), 3) + + def test_run_tvboptim_optimizes_conduction_speed(self): + """from_string(exp_yaml).run("tvboptim") builds and acts on the conduction-speed gradient. + + The default run executes every stage, including the gradient-based + optimization of conduction_speed. The gradient flows speed -> delays -> + the interpolated coupling. Requires tvbo to expose conduction_speed as an + optimizable parameter (currently excluded from collect_state and baked + into the delays at codegen time); until then the run raises with + conduction_speed missing from the state and we skip. + """ + try: + result = SimulationExperiment.from_string(exp_yaml).run("tvboptim") + except Exception as exc: # noqa: BLE001 + if _is_conduction_speed_unsupported(exc): + self.skipTest( + "tvbo does not yet expose conduction_speed as an optimizable parameter: " + "include it in SimulationExperiment.collect_state and recompute " + "delays = lengths / conduction_speed from that state leaf at runtime " + "(the tvboptim interpolate_delays path then makes the gradient flow)." + ) + raise + self.assertIsNotNone( + result.optimizations.speed_fit + ) # gradient built; optimizer stepped + + +if __name__ == "__main__": + unittest.main()