diff --git a/.gitignore b/.gitignore index 60b17e9..309df45 100644 --- a/.gitignore +++ b/.gitignore @@ -70,4 +70,7 @@ docs/_styles-quartodoc.css Thumbs.db # AI Skills -src/tvboptim/claude/skills \ No newline at end of file +src/tvboptim/claude/skills + +# Stray cache dir from running doc notebooks outside their own directory +/cache/ diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 3083559..47c789c 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -25,6 +25,7 @@ website: - text: "Workflows" menu: - file: workflows/RWW.qmd + - file: workflows/TBPTT.qmd - file: workflows/JR.qmd - file: workflows/EI_Tuning.qmd - file: workflows/Hopf_Pareto_ParallelOpt.qmd @@ -42,6 +43,7 @@ website: menu: - file: advanced/buffer_strategies.qmd - file: advanced/gradient_checkpointing.qmd + - file: advanced/streaming_reductions.qmd - file: advanced/coupling_freezing.qmd - file: advanced/subspace_coupling.qmd - text: "Reference" @@ -204,3 +206,12 @@ quartodoc: desc: The unified interface for preparing experiments and networks for simulation contents: - tvbo.prepare + + - title: Analysis + desc: Characterizing the dynamical regime + package: tvboptim.experimental.network_dynamics.analysis + contents: + - lyapunov + - lyapunov_spectrum + - adiabatic_scan + - AdiabaticScanResult diff --git a/docs/advanced/cache/gradient_checkpointing_benchmark/block_size_sweep.pkl b/docs/advanced/cache/gradient_checkpointing_benchmark/block_size_sweep.pkl new file mode 100644 index 0000000..7643f1d Binary files /dev/null and b/docs/advanced/cache/gradient_checkpointing_benchmark/block_size_sweep.pkl differ diff --git a/docs/advanced/cache/gradient_checkpointing_benchmark/checkpoint_sweep.pkl b/docs/advanced/cache/gradient_checkpointing_benchmark/checkpoint_sweep.pkl deleted file mode 100644 index 45bea70..0000000 Binary files a/docs/advanced/cache/gradient_checkpointing_benchmark/checkpoint_sweep.pkl and /dev/null differ diff --git a/docs/advanced/gradient_checkpointing.qmd b/docs/advanced/gradient_checkpointing.qmd index 049748d..54b3c10 100644 --- a/docs/advanced/gradient_checkpointing.qmd +++ b/docs/advanced/gradient_checkpointing.qmd @@ -50,13 +50,13 @@ except ImportError: The standard remedy is **gradient checkpointing**: instead of saving every step's activations for the backward pass, save only a sparse subset and recompute the missing ones on demand. TVB-Optim implements this for the native -solver path as a single optional knob on `NativeSolver`: +solver path as a single optional knob on `NativeSolver`, `block_size`: ```python -solver = Heun(checkpoint_every=256) +solver = Heun(block_size=256) ``` -When `checkpoint_every` is `None` (the default) the integration runs as a +When `block_size` is `None` (the default) the integration runs as a single `jax.lax.scan` exactly as before — there is no overhead and no behaviour change. When set to an integer `K`, the scan is split into an outer scan over blocks of `K` steps wrapped in `jax.checkpoint`, with an inner scan running @@ -66,17 +66,34 @@ overhead — typically 1.3–1.7× depending on workload**, since backward is usually already several times more expensive than forward and one extra forward pass adds only a fraction to that total. Forward time is unchanged. The optimum for memory minimisation lies near `K ≈ √n_steps`. +`block_size` is the one block unit for the solver's streaming features (it was +formerly named `checkpoint_every`). Two consequences matter for this notebook: + +- On a **stochastic** network, setting `block_size` also switches the noise to + per-block generation, which *reseeds* the realization relative to the + monolithic single draw. To keep this a clean checkpointing benchmark (common + random numbers, bit-exact across `block_size`), we inject one fixed noise + tensor so the block path uses it verbatim instead of streaming. The + streaming-noise memory saving is a separate axis; see + [Streaming Reductions](streaming_reductions.qmd). +- The same `block_size` is also the grain for online `reduce` statistics (e.g. + streamed FC), not covered here. + ::: {.callout-note} ## Scope and limitations - **Native solvers only.** `DiffraxSolver` is not affected; Diffrax exposes its own `RecursiveCheckpointAdjoint` for adaptive ODE solves, but it does not support delays. -- **No effect when `checkpoint_every is None`.** The call site falls through to +- **No effect when `block_size is None`.** The call site falls through to the original `jax.lax.scan(op, state0, scan_inputs)` line. The default behaviour is bit-exact with prior versions. - **Forward is unaffected by memory savings.** Forward simulations do not retain step activations regardless; checkpointing only matters when you take a gradient. +- **SDE noise is held fixed here.** With a stochastic network, `block_size` + streams (reseeds) the noise; this benchmark injects a fixed noise tensor so + every config integrates the same path and the checkpointed gradient stays + bit-exact to the uncheckpointed one. ::: ```{python} @@ -183,7 +200,7 @@ top of the dynamics state, the noise tensor, and the auxiliary tape. ## Benchmark We benchmark forward time, gradient time, and (where the backend supports it) -peak device memory across a sweep of `checkpoint_every` values. The sweep +peak device memory across a sweep of `block_size` values. The sweep covers: - `None` — the default, single `jax.lax.scan`. Reference for performance. @@ -204,11 +221,21 @@ covers: # is a clean divisor of n_steps (no tail). Most other values do not divide # n_steps exactly and therefore exercise the main-scan + tail-scan path, # which matters for the memory story — see "Reading the memory curve". -CHECKPOINT_VALUES = [None, 32, 128, 256, 512, 1024, 2048, 8192, 30000, N_STEPS] +BLOCK_SIZE_VALUES = [None, 32, 128, 256, 512, 1024, 2048, 8192, 30000, N_STEPS] N_FORWARD_RUNS = 3 N_GRADIENT_RUNS = 3 G_INIT = jnp.asarray(0.5) +# Fixed noise realization (common random numbers). Injecting this into the +# config makes `block_size` do pure gradient checkpointing rather than per-block +# streaming: every config integrates the same noise path, so the checkpointed +# gradient stays bit-exact to the uncheckpointed one and the benchmark isolates +# the activation-tape effect. Shape is [n_steps, n_noise_states, n_nodes]. +n_noise_states = len(network.noise._state_indices) +FIXED_NOISE = jax.random.normal( + network.noise.key, (N_STEPS, n_noise_states, n_nodes) +) + class RSSPeakMonitor: """Context manager that records peak process RSS during the with-block. @@ -276,11 +303,14 @@ class RSSPeakMonitor: self._stop.wait(self.sample_interval) -def benchmark_one(checkpoint_every, fc_target): +def benchmark_one(block_size, fc_target): """Time forward + gradient, capture peak RSS during gradient, and return the gradient value for cross-check.""" - solver = Heun(checkpoint_every=checkpoint_every) + solver = Heun(block_size=block_size) solve_fn, state = prepare(network, solver, t0=0.0, t1=T1, dt=DT) + # Inject the fixed noise so block_size does pure checkpointing (no per-block + # streaming / reseed); all configs then share the same realization. + state._internal.noise_samples = FIXED_NOISE solve_fn = jax.jit(solve_fn) def loss(G): @@ -338,12 +368,12 @@ def benchmark_one(checkpoint_every, fc_target): } -@cache("checkpoint_sweep", redo=False) +@cache("block_size_sweep", redo=True) def run_sweep(): results = {} - for k in CHECKPOINT_VALUES: + for k in BLOCK_SIZE_VALUES: label = "None" if k is None else str(k) - print(f"checkpoint_every = {label} ...", flush=True) + print(f"block_size = {label} ...", flush=True) results[label] = benchmark_one(k, fc_target) gc.collect() return results @@ -356,7 +386,7 @@ sweep_results = run_sweep() ```{python} #| label: fig-checkpoint-benchmark -#| fig-cap: "**Gradient checkpointing benchmark.** Top row — *time*. Top-left: Forward and gradient wall time as a function of `checkpoint_every`; dotted horizontals mark the `None` baseline, dashed vertical marks `√n_steps`. Top-right: Per-call gradient-to-forward ratio. Bottom row — *memory*. Bottom-left: Peak RSS delta during a gradient call vs `checkpoint_every`, showing the `O(n_steps/K + K)` minimum near `√n_steps`. Bottom-right: Memory–time Pareto, with the `None` star at the low-time / high-memory extreme and checkpointed points tracing the front." +#| fig-cap: "**Gradient checkpointing benchmark.** Top row — *time*. Top-left: Forward and gradient wall time vs `block_size` on a shared log y-axis (the two sit about a decade apart); dashed horizontals mark each curve's `None` baseline, the dashed vertical marks `√n_steps`. Top-right: Per-call gradient-to-forward ratio. Bottom row — *memory*. Bottom-left: Peak RSS delta during a gradient call vs `block_size` (linear y), showing the `O(n_steps/K + K)` minimum near `√n_steps`. Bottom-right: Memory–time Pareto (front and `None` points labelled), with the `None` star at the low-time / high-memory extreme and checkpointed points tracing the front." #| echo: true #| code-fold: true #| code-summary: "Plotting code" @@ -364,7 +394,7 @@ sweep_results = run_sweep() baseline = sweep_results["None"] sqrt_n = np.sqrt(N_STEPS) -# K-axis panels drop "None" — it has no x-coordinate on a checkpoint_every +# K-axis panels drop "None" — it has no x-coordinate on a block_size # axis, only a horizontal-reference role. The Pareto panel keeps it as a # distinct star marker because its axes are (time, memory) and there is no # overlap risk. @@ -430,16 +460,13 @@ with plt.rc_context({ # === Top row: time === - # --- Top-left: time vs checkpoint_every (twin y-axes) --- - # Forward and gradient times are an order of magnitude apart, so a - # shared y-axis would collapse the forward curve to a flat line near - # the bottom. Twin axes let each curve fill its own range; we use - # linear scaling on both because each axis spans well under a decade. - # Axis spines and tick colours are tinted to indicate which curve goes - # with which side. + # --- Top-left: time vs block_size (single log y-axis) --- + # Forward (~0.1 s) and gradient (~1 s) are about a decade apart, so a single + # log y-axis separates them cleanly and keeps each None baseline next to its + # own curve without the two dashed references overlapping (which a linear + # twin-axis layout did). The fine overhead detail lives in the panel to the + # right (grad / forward ratio). ax = axes[0, 0] - ax_g = ax.twinx() - fwd_color = "steelblue" grad_color = "firebrick" @@ -447,30 +474,17 @@ with plt.rc_context({ label="forward", lw=1.8, markersize=7, capsize=3) ax.axhline(baseline["fwd_mean"], color=fwd_color, linestyle="dashed", alpha=0.7, label="forward (None)") - - ax_g.errorbar(xs, grad, yerr=grad_err, marker="s", color=grad_color, - label="gradient", lw=1.8, markersize=7, capsize=3) - ax_g.axhline(baseline["grad_mean"], color=grad_color, linestyle="dashed", - alpha=0.7, label="gradient (None)") + ax.errorbar(xs, grad, yerr=grad_err, marker="s", color=grad_color, + label="gradient", lw=1.8, markersize=7, capsize=3) + ax.axhline(baseline["grad_mean"], color=grad_color, linestyle="dashed", + alpha=0.7, label="gradient (None)") ax.set_xscale("log") - ax.set_xlabel("checkpoint_every") - ax.set_ylabel("forward wall time (s)", color=fwd_color) - ax_g.set_ylabel("gradient wall time (s)", color=grad_color) - ax.set_title("Time vs checkpoint_every") - - # Tint the axis ticks and spines to match the data they describe. - ax.tick_params(axis="y", colors=fwd_color) - ax.spines["left"].set_color(fwd_color) - ax_g.tick_params(axis="y", colors=grad_color) - ax_g.spines["right"].set_color(grad_color) - ax_g.spines["left"].set_visible(False) - - # Combined legend from both axes. - h1, l1 = ax.get_legend_handles_labels() - h2, l2 = ax_g.get_legend_handles_labels() - ax.legend(h1 + h2, l1 + l2, loc="best", framealpha=0.9) - + ax.set_yscale("log") + ax.set_xlabel("block_size") + ax.set_ylabel("wall time (s)") + ax.set_title("Time vs block_size") + ax.legend(loc="best", framealpha=0.9, ncol=2) ax.grid(alpha=0.3, which="both") _mark_sqrt_n(ax) @@ -483,7 +497,7 @@ with plt.rc_context({ ax.axhline(baseline_ratio, color="darkgreen", linestyle="dashed", alpha=0.7, label=f"None baseline ({baseline_ratio:.2f}×)") ax.set_xscale("log") - ax.set_xlabel("checkpoint_every") + ax.set_xlabel("block_size") ax.set_ylabel("grad / forward") ax.set_title("Gradient overhead") ax.grid(alpha=0.3, which="both") @@ -492,18 +506,19 @@ with plt.rc_context({ # === Bottom row: memory === - # --- Bottom-left: memory vs checkpoint_every --- + # --- Bottom-left: memory vs block_size --- ax = axes[1, 0] if has_memory: ax.plot(xs, mem_ck_mb, marker="D", color="purple", lw=1.8, markersize=8, label="peak RSS delta during grad") ax.axhline(baseline["peak_bytes_delta"] / 1e6, color="purple", linestyle="dashed", alpha=0.7, label="None baseline") - ax.set_xscale("log") - ax.set_yscale("log") - ax.set_xlabel("checkpoint_every") + ax.set_xscale("log") # block_size spans decades; y stays linear so the + # U-shape and the absolute MB differences read directly. + ax.set_ylim(bottom=0) + ax.set_xlabel("block_size") ax.set_ylabel("peak RSS delta during grad (MB)") - ax.set_title("Memory vs checkpoint_every") + ax.set_title("Memory vs block_size") ax.grid(alpha=0.3, which="both") ax.legend(loc="best", framealpha=0.9) _mark_sqrt_n(ax) @@ -515,7 +530,7 @@ with plt.rc_context({ bbox=dict(boxstyle="round,pad=0.5", facecolor="lightyellow")) ax.set_xticks([]) ax.set_yticks([]) - ax.set_title("Memory vs checkpoint_every (unavailable)") + ax.set_title("Memory vs block_size (unavailable)") # --- Bottom-right: memory–time Pareto --- # Two cleanup ideas vs the old "connect-by-time" line, which crossed @@ -557,8 +572,13 @@ with plt.rc_context({ else: ax.scatter([x], [y], s=55, facecolor="white", edgecolor="purple", linewidth=1.3, zorder=2) - ax.annotate(l, (x, y), textcoords="offset points", - xytext=(8, 6), fontsize=11) + # Label only the front points and None: the dominated points cluster + # near the front and their labels collide. Alternate the vertical + # offset to further reduce overlap among the labelled ones. + if on_front or l == "None": + dy = 8 if (i % 2 == 0) else -12 + ax.annotate(l, (x, y), textcoords="offset points", + xytext=(8, dy), fontsize=10) ax.set_xlabel("gradient time (s)") ax.set_ylabel("peak RSS delta during grad (MB)") @@ -599,7 +619,7 @@ with plt.rc_context({ ## Reading the Memory Curve -The memory-vs-`checkpoint_every` panel (bottom-left) shows the classical analysis. Peak gradient +The memory-vs-`block_size` panel (bottom-left) shows the classical analysis. Peak gradient memory scales as $$ \mathrm{peak\,memory} \;\approx\; \underbrace{\frac{n_\text{steps}}{K} \cdot c_\text{outer}}_{\text{block-boundary tape}} \;+\; \underbrace{K \cdot c_\text{inner}}_{\text{per-block inner tape during backward}} $$ @@ -673,7 +693,7 @@ from the saved tape. #| echo: true baseline = sweep_results["None"] -print(f"{'checkpoint_every':<20} {'loss':<22} {'grad':<22} {'|Δgrad / grad|':<18}") +print(f"{'block_size':<20} {'loss':<22} {'grad':<22} {'|Δgrad / grad|':<18}") print("-" * 82) for label, res in sweep_results.items(): rel = abs((res["grad_value"] - baseline["grad_value"]) / baseline["grad_value"]) @@ -700,7 +720,7 @@ or the device-memory delta if `jax.devices()[0].memory_stats()` is available baseline = sweep_results["None"] header = ( - f"{'ckpt_every':<12} " + f"{'block_size':<12} " f"{'fwd_s':<14} " f"{'grad_s':<14} " f"{'grad/fwd':<10} " @@ -750,7 +770,7 @@ print(f"# device: {jax.devices()[0].platform} jax {jax.__version__}") ## No-Regression Check -Because `checkpoint_every=None` selects the original `jax.lax.scan` call site +Because `block_size=None` selects the original `jax.lax.scan` call site verbatim (a literal if-branch), forward and gradient times for the default must be **within timing noise** of the previous non-checkpointed implementation. The benchmark above implicitly verifies this: the `None` row should be @@ -758,7 +778,7 @@ statistically indistinguishable from any prior measurement of the unchecked path. `K = n_steps` is **not** equivalent to `None` — it still wraps the (single) inner scan in `jax.checkpoint`, so the backward pass recomputes the entire forward once, costing roughly 1.3× the `None` -gradient time. Only `checkpoint_every=None` skips checkpointing entirely. +gradient time. Only `block_size=None` skips checkpointing entirely. ## Practical Guidance @@ -770,11 +790,11 @@ from tvboptim.experimental.network_dynamics.solvers import Heun solver = Heun() # Memory-optimal default when gradients no longer fit in memory. -solver = Heun(checkpoint_every=int(math.sqrt(n_steps))) +solver = Heun(block_size=int(math.sqrt(n_steps))) # Aggressive: minimal memory, maximal recompute. Use only if the sqrt # default still OOMs. -solver = Heun(checkpoint_every=64) +solver = Heun(block_size=64) ``` The same field works on `Euler`, `Heun`, `RungeKutta4`, and any diff --git a/docs/advanced/streaming_reductions.qmd b/docs/advanced/streaming_reductions.qmd new file mode 100644 index 0000000..1c5e1c5 --- /dev/null +++ b/docs/advanced/streaming_reductions.qmd @@ -0,0 +1,424 @@ +--- +title: "Streaming Reductions for Long Forward Simulations" +subtitle: "Computing BOLD / FC From Long Rollouts Without Holding the Trajectory" +format: + html: + code-fold: false + toc: true + echo: false + embed-resources: true + fig-width: 8 + out-width: "100%" +jupyter: python3 +execute: + cache: true +--- + +Try this notebook interactively: + +[Download .ipynb](https://github.com/virtual-twin/tvboptim/blob/main/docs/advanced/streaming_reductions.ipynb){.btn .btn-primary download="streaming_reductions.ipynb"} +[Download .qmd](streaming_reductions.qmd){.btn .btn-secondary download="streaming_reductions.qmd"} +[Open in Colab](https://colab.research.google.com/github/virtual-twin/tvboptim/blob/main/docs/advanced/streaming_reductions.ipynb){.btn .btn-warning target="_blank"} + +## Introduction + +[Gradient Checkpointing](gradient_checkpointing.qmd) tackles *backward*-pass +memory: the activation tape that a gradient needs. This notebook tackles the +other side of the same problem, *forward*-pass memory, which matters even +when you never take a gradient. + +A long forward rollout holds two `O(n_steps)` tensors in memory: + +1. the stacked output **trajectory**, `[n_steps, n_voi, n_nodes]`, and +2. for a stochastic network, the pre-sampled **noise tensor**, the same size. + +But most of what we actually want from a long run is a *reduced* statistic: +functional connectivity (a long-time covariance), a BOLD signal (a kernel +convolution / hemodynamic ODE), a temporal average. Each is a **reduction** of +the trajectory: a `fold` that consumes the time series and keeps only a small +running aggregate, the streaming counterpart of a one-shot +`compute_fc(trajectory)`. The native solver exposes this through the per-call +`reduce` kwarg, which folds the trajectory **block-by-block** into the statistic +instead of stacking it, so resident memory scales with `block_size`, not +`n_steps`. (If you come from TVB, this is the monitor idea generalized to an +arbitrary online reduction.) + +```{python} +#| output: false +#| echo: false +try: + import google.colab + print("Running in Google Colab - installing dependencies...") + !pip install -q tvboptim + print("✓ Dependencies installed!") +except ImportError: + pass +``` + +The motivation is throughput, not differentiation: a forward-only statistic that +fits in a few MB instead of hundreds lets you **pack many independent +simulations onto one GPU** (parameter sweeps, posterior sampling, seed +ensembles) where the stacked trajectories would not fit. The same machinery +composes with `grad_horizon` for a *differentiable* FC over a long rollout, the +window bounding the gradient horizon for stability and `block_size` bounding the +backward memory by the checkpointing model (not the forward's `O(block_size)`, +see Practical Guidance). That combination is covered in [Truncated +Backpropagation](../workflows/TBPTT.qmd). Here we stay forward-only. + +We compare three ways to get a BOLD signal from the same long simulation: + +| variant | forward | output | resident memory | +|---|---|---|---| +| **classical** | monolithic `Heun()` | full trajectory, monitor post-hoc | trajectory **+** noise tensor | +| **blocked** | `Heun(block_size=K)` | full trajectory, monitor post-hoc | trajectory only | +| **streaming** | `Heun(block_size=K)` | `reduce=streaming_hrf_bold` | one block | + +The point of the middle row is that **blocking the simulation alone does not +bound the output memory.** It streams the noise (removing tensor 2), but the +trajectory is still stacked. Only `reduce` removes the trajectory too. + +::: {.callout-note} +## Scope and requirements +- **Native solvers only.** The Diffrax dispatch rejects `reduce`. +- **`streaming_hrf_bold` needs a `SubSampling` downsample** (a uniform integer + stride; `TemporalAverage`'s float-rounded windows are not faithfully + streamable) and **`block_size` / `n_steps` multiples of the BOLD period in raw + steps** so each block emits a whole number of TR samples. +- **Noise realizations differ by construction.** A blocked SDE run draws its + noise per block (`fold_in`), reseeding relative to the monolithic single draw. + So *blocked* and *streaming* share one realization (their BOLD is pointwise + equal, the equivalence assertion below), while *classical* is a different + realization, shown for the memory/runtime axis only. See the + [solver reference](../network_dynamics/solvers.qmd) for the streaming-noise + and `reduce` design. +::: + +```{python} +#| output: false +#| code-fold: true +#| code-summary: "Environment setup and imports" +#| echo: true + +import time +import gc +import threading +import numpy as np +import matplotlib.pyplot as plt +import jax +import jax.numpy as jnp + +try: + import psutil + _HAS_PSUTIL = True +except ImportError: + _HAS_PSUTIL = False + +jax.config.update("jax_enable_x64", True) + +from tvboptim.experimental.network_dynamics import Network, solve +from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang +from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling +from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph +from tvboptim.experimental.network_dynamics.noise import AdditiveNoise +from tvboptim.experimental.network_dynamics.solvers import Heun +from tvboptim.observations.tvb_monitors import HRFBold, SubSampling, streaming_hrf_bold +from tvboptim.data import load_structural_connectivity +from tvboptim.utils import set_cache_path, cache + +set_cache_path("./streaming_reductions") +``` + +## Workload: RWW + Delays + BOLD + +The same delayed Reduced Wong-Wang network as the [gradient checkpointing +notebook](gradient_checkpointing.qmd): `dk_average` structural connectivity (68 +regions), tract lengths converted to delays at 4 mm/ms. Here we run it +**long and forward-only** to a BOLD signal. + +```{python} +#| echo: true +#| output: false + +DT = 1.0 # ms +T1 = 120_000.0 # ms (120 s), a long statistical rollout +N_STEPS = int(T1 / DT) # 120_000 steps +CONDUCTION_SPEED = 4.0 # mm/ms + +weights, lengths, region_labels = load_structural_connectivity(name="dk_average") +weights = weights / np.max(weights) +delays = jnp.asarray(lengths / CONDUCTION_SPEED) +n_nodes = weights.shape[0] + +graph = DenseDelayGraph( + weights=jnp.asarray(weights), delays=delays, region_labels=region_labels +) +dynamics = ReducedWongWang(w=0.5, I_o=0.32, INITIAL_STATE=(0.3,)) +coupling = DelayedLinearCoupling(incoming_states="S", G=0.5, buffer_strategy="roll") +noise = AdditiveNoise(sigma=0.00283, apply_to="S", key=jax.random.key(0)) +network = Network( + dynamics=dynamics, coupling={"delayed": coupling}, graph=graph, noise=noise +) + +# BOLD monitor with a SubSampling downsample (required for streaming). With +# DT=1.0: period_in_steps = (downsample_period/DT) * (period/downsample_period) +# = period/DT = 1000 raw steps per BOLD sample, +# so block_size and N_STEPS must be multiples of 1000. +monitor = HRFBold( + period=1000.0, # TR = 1 s + downsample_period=10.0, + downsample=SubSampling(period=10.0), + voi=0, +) +BLOCK_SIZE = 2000 +PERIOD_IN_STEPS = 1000 +assert N_STEPS % PERIOD_IN_STEPS == 0 and BLOCK_SIZE % PERIOD_IN_STEPS == 0 + +traj_mb = N_STEPS * n_nodes * 8 / 1e6 +print(f"n_nodes={n_nodes} n_steps={N_STEPS} block_size={BLOCK_SIZE}") +print(f"full trajectory ~{traj_mb:.0f} MB; monolithic noise tensor ~the same again") +``` + +## Benchmark + +`RSSPeakMonitor` records the peak process-RSS delta during a call (a pragmatic +CPU proxy; on GPU/TPU read `jax.devices()[0].memory_stats()` instead). We +measure compile time, peak RSS, and best wall time for each of the three +variants, then compare the BOLD outputs. + +```{python} +#| echo: true +#| output: false +#| code-fold: true +#| code-summary: "Benchmark setup" + + +class RSSPeakMonitor: + """Record peak process RSS over the with-block (peak minus entry baseline). + + A background thread polls ``psutil`` RSS at ``sample_interval`` and tracks + the max. On CPU, XLA buffers live in process RSS, so transient activations + and the stacked trajectory show up here. ``None`` if psutil is missing.""" + + def __init__(self, sample_interval=0.02): + self.sample_interval = sample_interval + self.peak_delta_bytes = None + + def __enter__(self): + if not _HAS_PSUTIL: + return self + self._process = psutil.Process() + self._baseline = self._process.memory_info().rss + self._peak = self._baseline + self._stop = threading.Event() + self._thread = threading.Thread(target=self._sample, daemon=True) + self._thread.start() + return self + + def __exit__(self, *exc): + if not _HAS_PSUTIL: + return False + self._stop.set() + self._thread.join() + self.peak_delta_bytes = max(0, self._peak - self._baseline) + return False + + def _sample(self): + while not self._stop.is_set(): + try: + self._peak = max(self._peak, self._process.memory_info().rss) + except Exception: + break + self._stop.wait(self.sample_interval) + + +# Variants (2) and (3) share one noise realization (same block_size forward); +# (1) is the vanilla monolithic forward, a different draw. +blocked_solver = Heun(block_size=BLOCK_SIZE) +monolithic_solver = Heun() + + +def classical(): + sol = solve(network, monolithic_solver, t0=0.0, t1=T1, dt=DT) + return monitor(sol).ys + + +def blocked(): + sol = solve(network, blocked_solver, t0=0.0, t1=T1, dt=DT) + return monitor(sol).ys + + +def streaming(): + return solve( + network, blocked_solver, t0=0.0, t1=T1, dt=DT, + reduce=streaming_hrf_bold(monitor, DT), + ) + + +def measure(fn, n_rep=3): + t0 = time.perf_counter() + jax.block_until_ready(fn()) # compile + first call + compile_s = time.perf_counter() - t0 + gc.collect() + mon = RSSPeakMonitor() + with mon: + out = fn() + jax.block_until_ready(out) + best = float("inf") + for _ in range(n_rep): + t = time.perf_counter() + jax.block_until_ready(fn()) + best = min(best, time.perf_counter() - t) + return { + "compile_s": compile_s, + "peak_mb": (mon.peak_delta_bytes / 1e6) if mon.peak_delta_bytes else None, + "best_s": best, + "bold": np.asarray(out), + } + + +@cache("streaming_bold_sweep", redo=True) +def run_sweep(): + results = {} + for name, fn in (("classical", classical), ("blocked", blocked), + ("streaming", streaming)): + print(f"{name} ...", flush=True) + results[name] = measure(fn) + gc.collect() + return results + + +results = run_sweep() +``` + +## Results + +```{python} +#| label: fig-streaming-memory +#| fig-cap: "**Streaming reductions: the forward-memory ladder.** Left: peak process-RSS delta for the three variants (log y). Classical holds the trajectory *and* the noise tensor; blocking the forward streams the noise away but still stacks the trajectory; only the streaming `reduce` bounds memory to one block. Right: the BOLD signal of one region. Blocked and streaming overlap exactly (shared noise); classical is a different but statistically equivalent realization." +#| echo: true +#| code-fold: true +#| code-summary: "Plotting code" + +names = ["classical", "blocked", "streaming"] +colors = {"classical": "firebrick", "blocked": "darkorange", "streaming": "seagreen"} + +fig, (ax_mem, ax_bold) = plt.subplots(1, 2, figsize=(13, 5)) + +# --- Left: peak memory ladder --- +mems = [results[n]["peak_mb"] for n in names] +if all(m is not None for m in mems): + bars = ax_mem.bar(names, mems, color=[colors[n] for n in names]) + ax_mem.set_yscale("log") + ax_mem.set_ylabel("peak RSS delta (MB)") + ax_mem.set_title("Forward memory") + for b, m in zip(bars, mems): + ax_mem.text(b.get_x() + b.get_width() / 2, m, f"{m:.0f}", + ha="center", va="bottom", fontsize=11) +else: + ax_mem.text(0.5, 0.5, "Peak memory unavailable\n(psutil not installed)", + transform=ax_mem.transAxes, ha="center", va="center") +ax_mem.grid(alpha=0.3, axis="y", which="both") + +# --- Right: one region's BOLD time course --- +node = 0 +for n in names: + y = results[n]["bold"][:, 0, node] + t = (np.arange(len(y)) + 1) * monitor.period / 1000.0 # s + ax_bold.plot(t, y, color=colors[n], lw=1.6, + alpha=0.9 if n != "blocked" else 0.6, + ls="--" if n == "streaming" else "-", label=n) +ax_bold.set_xlabel("time (s)") +ax_bold.set_ylabel(f"BOLD (region {node})") +ax_bold.set_title("BOLD signal") +ax_bold.legend(framealpha=0.9) +ax_bold.grid(alpha=0.3) + +plt.tight_layout() +plt.show() +``` + +## Equivalence + +Blocked and streaming share a noise realization, so their BOLD must agree to FFT +float-reassociation error. Classical is a different draw, compared on amplitude +only. + +```{python} +#| echo: true + +bold_classical = results["classical"]["bold"] +bold_blocked = results["blocked"]["bold"] +bold_streaming = results["streaming"]["bold"] + +max_abs = float(np.max(np.abs(bold_blocked - bold_streaming))) +scale = float(np.max(np.abs(bold_blocked))) + 1e-12 +print(f"BOLD shape {bold_streaming.shape} ({bold_streaming.shape[0]} samples)\n") +print(f"blocked vs streaming (shared noise): max abs {max_abs:.2e} " + f"(rel {max_abs / scale:.2e}) " + f"{'MATCH' if max_abs / scale < 1e-4 else 'MISMATCH'}") +print(f"classical (different noise draw): std {bold_classical.std():.3e} vs " + f"streaming std {bold_streaming.std():.3e} (statistically comparable)") +``` + +## Summary Table + +```{python} +#| echo: true +#| code-fold: true +#| code-summary: "Table code" + +base = results["classical"]["peak_mb"] +header = f"{'variant':<12} {'peak_MB':<10} {'vs classical':<14} {'best_s':<10} {'compile_s':<10}" +print(header) +print("-" * len(header)) +for n in names: + r = results[n] + peak = f"{r['peak_mb']:.1f}" if r["peak_mb"] is not None else "NA" + rel = f"{r['peak_mb'] / base:.2f}x" if (r["peak_mb"] and base) else "NA" + print(f"{n:<12} {peak:<10} {rel:<14} {r['best_s']:<10.3f} {r['compile_s']:<10.3f}") + +print() +print(f"# workload: n_nodes={n_nodes}, n_steps={N_STEPS}, dt={DT}, T={T1/1000:.0f}s, " + f"block_size={BLOCK_SIZE}") +print(f"# device: {jax.devices()[0].platform} jax {jax.__version__}") +``` + +## Practical Guidance + +```python +from tvboptim.experimental.network_dynamics import solve +from tvboptim.experimental.network_dynamics.solvers import Heun +from tvboptim.observations import welford_cov +from tvboptim.observations.tvb_monitors import HRFBold, SubSampling, streaming_hrf_bold + +# Streamed functional connectivity: returns the FC matrix, never stacks ys. +fc = solve(network, Heun(block_size=2000), t0=0.0, t1=120_000.0, dt=1.0, + reduce=welford_cov(s_var=0)) + +# Streamed BOLD: SubSampling downsample, block_size a multiple of period/dt. +monitor = HRFBold(period=1000.0, downsample_period=10.0, + downsample=SubSampling(period=10.0), voi=0) +bold = solve(network, Heun(block_size=2000), t0=0.0, t1=120_000.0, dt=1.0, + reduce=streaming_hrf_bold(monitor, dt=1.0)) +``` + +Rules of thumb: + +- **Reach for `reduce` when the trajectory itself is the binding memory cost**, + like long statistical rollouts you pack many of onto a GPU. For a single run + that fits, the stacked trajectory is simpler. +- **`block_size` trades throughput for memory.** Larger blocks amortize the + per-block work toward the monolithic forward speed; very small blocks serialize + it. Pick the smallest block whose chunk comfortably fits. +- **Differentiating the FC does not keep the forward's `O(block_size)` win.** + Reverse mode retains every block's boundary carry (dynamics state, delay + history buffer, accumulator), so backward memory follows the + [checkpointing](gradient_checkpointing.qmd) model + `O(n_steps/block_size + block_size)`, U-shaped in `block_size`, not constant in + `n_steps`. `block_size` bounds that memory; `grad_horizon` is separate, it + bounds the gradient *horizon* (stability), not memory. For a long + differentiable FC set both; see + [Truncated Backpropagation](../workflows/TBPTT.qmd). +- **`welford_cov` is a memory win, not a flop win**: post-hoc `compute_fc` is one + GEMM; the streamed reducer is `n_blocks` batched merges of the same total cost. +``` diff --git a/docs/index.qmd b/docs/index.qmd index 03ac2b3..a73f967 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -216,6 +216,7 @@ For direct model specification and optimization workflows: - [Network Dynamics Introduction](./network_dynamics/network_dynamics.qmd) - Overview of the framework architecture - [Complete Optimization Workflows](./network_dynamics/network_dynamics.qmd#complete-optimization-workflows) - End-to-end examples: - [Reduced Wong-Wang BOLD FC Optimization](./workflows/RWW.qmd) - Fitting functional connectivity from fMRI + - [Truncated Backpropagation for Long Simulations](./workflows/TBPTT.qmd) - Bounding the gradient horizon with `grad_horizon` - [Jansen-Rit MEG Peak Frequency Gradient](./workflows/JR.qmd) - Reproducing spatial frequency patterns in MEG data - [Excitation Inhibition Balance Tuning](./workflows/EI_Tuning.qmd) - Connectivity scale optimization with and without automatic differentiation diff --git a/docs/network_dynamics/solvers.qmd b/docs/network_dynamics/solvers.qmd index 14408a2..274d956 100644 --- a/docs/network_dynamics/solvers.qmd +++ b/docs/network_dynamics/solvers.qmd @@ -259,28 +259,64 @@ This is the same default TVB uses, made explicit. See full performance and accuracy analysis. ::: -### Gradient Checkpointing +### Long Simulations: Memory, Gradient Horizon, Online Statistics -Differentiating through a long simulation can run out of memory on the backward -pass: `jax.lax.scan` saves every step's carry, and for delayed couplings that -carry includes the history buffer. Every native solver takes a `checkpoint_every` -argument to trade recompute for memory: +Long simulations stress differentiation in two ways: the backward pass can run +out of memory, and the gradient itself can explode over a long horizon. Native +solvers expose three back-compatible knobs (all off by default) for this. + +**`block_size` — gradient checkpointing (and the streaming block unit).** +`jax.lax.scan` saves every step's carry for the backward pass, and for delayed +couplings that carry includes the history buffer, so backward memory grows with +`n_steps`. `block_size=K` splits the scan into `jax.checkpoint`-wrapped blocks of +`K` steps, trading one extra forward recompute for `O(n_steps/K + K)` backward +memory (minimum near `K ≈ sqrt(n_steps)`): ```python from tvboptim.experimental.network_dynamics.solvers import Heun -# Default: single scan, every step's carry kept for the backward pass -solver = Heun() - -# Checkpoint blocks of K steps: backward memory drops from O(n_steps) -# to O(n_steps / K + K), at the cost of one extra forward recompute -solver = Heun(checkpoint_every=200) +solver = Heun() # default: single scan, full backward memory +solver = Heun(block_size=200) # checkpoint blocks of 200 steps ``` -This only affects gradient computation; forward-only runs are untouched. Memory -is minimized near `K ≈ sqrt(n_steps)`. See +`block_size` is the one block unit for the streaming features below (it was +formerly named `checkpoint_every`). On a **stochastic** network it also switches +the noise to per-block generation, which reseeds the realization relative to the +monolithic single draw; this is intended and lets long SDE runs avoid holding the +full `O(n_steps)` noise tensor. Forward-only runs are untouched. See [Gradient Checkpointing](../advanced/gradient_checkpointing.qmd) for the memory -model, the `K` sweep, and the buffer-strategy caveats. +model and the `K` sweep. + +**`grad_horizon` — truncated backpropagation (gradient horizon).** Sever the carry +gradient every `W` steps so the backward sensitivity stops amplifying past one +window. The forward is bit-identical (it is a stability knob, not a speed knob); +only the backward changes. When combined with `block_size` it is snapped to a +multiple of it so window and block boundaries align: + +```python +solver = Heun(grad_horizon=2000) # bounded gradient horizon +solver = Heun(grad_horizon=2000, block_size=200) # horizon + bounded memory +``` + +See [Truncated Backpropagation](../workflows/TBPTT.qmd). + +**`reduce` — online statistics (per-call, not a solver attribute).** A +`reduce=(init, update, finalize)` kwarg on `solve` / `prepare` folds the +trajectory block-by-block into a running statistic instead of stacking it, so the +output memory is independent of `n_steps`. `welford_cov` streams functional +connectivity; `streaming_hrf_bold` streams a BOLD signal: + +```python +from tvboptim.observations import welford_cov + +# Online FC over a long, memory-bounded SDE rollout; returns the FC matrix. +fc = solve(network, Heun(block_size=2000), t0=0.0, t1=100_000.0, dt=0.1, + reduce=welford_cov(s_var=0)) +``` + +`reduce` is native-only (the Diffrax dispatch rejects it). See +[Streaming Reductions](../advanced/streaming_reductions.qmd) for the streaming +and `reduce` design, with a memory and runtime comparison. --- diff --git a/docs/workflows/RWW.qmd b/docs/workflows/RWW.qmd index 32f1e8a..c5f26a1 100644 --- a/docs/workflows/RWW.qmd +++ b/docs/workflows/RWW.qmd @@ -67,9 +67,6 @@ import copy import optax from scipy import io -# Jax enable x64 -jax.config.update("jax_enable_x64", True) - # Import from tvboptim from tvboptim.types import Parameter, Space, GridAxis from tvboptim.types.stateutils import show_parameters @@ -175,8 +172,8 @@ We combine the RWW dynamics with structural connectivity to create a whole-brain # Create network components graph = DenseGraph(weights, region_labels=region_labels) -dynamics = ReducedWongWang(w=0.5, I_o=0.32, INITIAL_STATE=(0.3,)) -coupling = FastLinearCoupling(local_states=["S"], G=0.5) +dynamics = ReducedWongWang(w=0.3, I_o=0.32, INITIAL_STATE=(0.3,)) +coupling = FastLinearCoupling(local_states=["S"], G=0.15) noise = AdditiveNoise(sigma=0.00283, apply_to="S") # Assemble the network @@ -195,7 +192,7 @@ We prepare the network for simulation and run an initial transient to reach a qu ```{python} #| echo: true # Prepare simulation: compile model and initialize state -t1 = 120_000 # Total simulation duration (ms) - 2 minutes +t1 = 90_000 # Total simulation duration (ms) - 1.5 minutes dt = 4.0 # Integration timestep (ms) model, state = prepare(network, Heun(), t1=t1, dt=dt) @@ -350,7 +347,7 @@ Before optimization, we explore how the model parameters affect FC quality. We s #| output: false # Create grid for parameter exploration -n = 32 +n = 16 # Set up parameter axes for exploration grid_state = copy.deepcopy(state) @@ -426,8 +423,8 @@ cb = MultiCallback([ @cache("optimize", redo=False) def optimize(): - opt = OptaxOptimizer(loss, optax.adam(0.01, b2=0.9999), callback=cb) - fitted_state, fitting_data = opt.run(state, max_steps=300) + opt = OptaxOptimizer(loss, optax.adam(0.01), callback=cb) + fitted_state, fitting_data = opt.run(state, max_steps=100) return fitted_state, fitting_data fitted_state, fitting_data = optimize() @@ -500,7 +497,7 @@ plt.tight_layout() ## Heterogeneous Optimization -Global parameters (same for all regions) may not capture region-specific variations needed for optimal FC fit. We now make parameters heterogeneous: each brain region gets its own `w` and `I_o` values, while keeping `G` global. +Global parameters (same for all regions) may not capture region-specific variations needed for optimal FC fit. We now make parameters heterogeneous: each brain region gets its own `w` value, while keeping `G` global. ```{python} #| echo: true @@ -510,10 +507,6 @@ fitted_state_het = copy.deepcopy(fitted_state) # Make w regional (one value per node) fitted_state_het.dynamics.w.shape = (n_nodes,) -# Also make I_o regional and mark as optimizable -fitted_state_het.dynamics.I_o = Parameter(fitted_state_het.dynamics.I_o) -fitted_state_het.dynamics.I_o.shape = (n_nodes,) - # Keep global coupling fixed at optimized value fitted_state_het.coupling.instant.G = fitted_state_het.coupling.instant.G.value @@ -526,8 +519,8 @@ show_parameters(fitted_state_het) @cache("optimize_het", redo=False) def optimize_het(): - opt = OptaxOptimizer(loss, optax.adam(0.004, b2=0.999), callback=cb) - fitted_state, fitting_data = opt.run(fitted_state_het, max_steps=200) + opt = OptaxOptimizer(loss, optax.adam(0.005), callback=cb) + fitted_state, fitting_data = opt.run(fitted_state_het, max_steps=100) return fitted_state, fitting_data fitted_state_het, fitting_data_het = optimize_het() @@ -631,7 +624,7 @@ Let's examine the fitted region-specific parameters and their relationship to st ```{python} #| label: fig-fitted-params -#| fig-cap: "**Fitted heterogeneous parameters.** Left: Fitted excitatory recurrence (w) for each region plotted against mean incoming structural connectivity strength. Right: Fitted external input (I_o) vs mean connectivity. Dashed lines show the global optimization values for reference. Regions with stronger structural connections tend to require different parameter values to achieve optimal FC fit, demonstrating the importance of region-specific tuning." +#| fig-cap: "**Fitted heterogeneous parameters.** Fitted excitatory recurrence (w) for each region plotted against mean incoming structural connectivity strength. The dashed line shows the global optimization value for reference. Regions with stronger structural connections tend to require different parameter values to achieve optimal FC fit, demonstrating the importance of region-specific tuning." #| code-fold: true #| code-summary: "Show plotting code" @@ -640,14 +633,12 @@ mean_connectivity = np.mean(weights, axis=1) # Extract fitted regional parameters w_fitted = fitted_state_het.dynamics.w.value.flatten() -I_o_fitted = fitted_state_het.dynamics.I_o.value.flatten() # Get global optimization values for reference w_global = fitted_state.dynamics.w.value -I_o_global = fitted_state.dynamics.I_o # Not optimized in global fit, but initial value # Create figure -fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 3.24)) +fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 3.24)) # Plot w vs mean connectivity ax1.scatter(mean_connectivity, w_fitted, alpha=0.7, s=30, color='royalblue', edgecolors='k', linewidths=0.5) @@ -658,15 +649,6 @@ ax1.set_title('Regional Excitatory Recurrence Parameters') ax1.legend(loc='best') ax1.grid(True, alpha=0.3) -# Plot I_o vs mean connectivity -ax2.scatter(mean_connectivity, I_o_fitted, alpha=0.7, s=30, color='royalblue', edgecolors='k', linewidths=0.5) -ax2.axhline(I_o_global, color='red', linestyle='--', linewidth=2, label=f'Initial I_o = {I_o_global:.3f}') -ax2.set_xlabel('Mean Incoming Connectivity') -ax2.set_ylabel('Fitted I_o (External Input)') -ax2.set_title('Regional External Input Parameters') -ax2.legend(loc='best') -ax2.grid(True, alpha=0.3) - plt.tight_layout() ``` diff --git a/docs/workflows/TBPTT.qmd b/docs/workflows/TBPTT.qmd new file mode 100644 index 0000000..cf56379 --- /dev/null +++ b/docs/workflows/TBPTT.qmd @@ -0,0 +1,680 @@ +--- +title: "Truncated Backpropagation for Chaotic Brain Networks" +subtitle: "Why long-rollout gradients break, and how the `grad_horizon` knob fixes them" +format: + html: + code-fold: false + toc: true + echo: false + embed-resources: true + fig-width: 8 + out-width: "100%" +jupyter: python3 +execute: + cache: true +--- + +Try this notebook interactively: + +[Download .ipynb](https://github.com/virtual-twin/tvboptim/blob/main/docs/workflows/TBPTT.ipynb){.btn .btn-primary download="TBPTT.ipynb"} +[Download .qmd](TBPTT.qmd){.btn .btn-secondary download="TBPTT.qmd"} +[Open in Colab](https://colab.research.google.com/github/virtual-twin/tvboptim/blob/main/docs/workflows/TBPTT.ipynb){.btn .btn-warning target="_blank"} + +## Introduction + +Fitting resting-state functional connectivity (FC) needs **long** simulations: +FC is a long-time covariance, so a stable estimate requires a long forward +rollout. Differentiating through that rollout is an RNN backward pass, and its +sensitivity grows like $\exp(T\,\Lambda)$ over a window of length $T$, where +$\Lambda$ is the largest Lyapunov exponent of the dynamics. While the network +sits at a fixed point or a limit cycle ($\Lambda \le 0$) this is harmless. Once +a control parameter pushes the network across a bifurcation into **chaos** +($\Lambda > 0$), the same backward pass amplifies without bound and the gradient +becomes numerically useless. + +This notebook makes that failure concrete on a Jansen-Rit whole-brain model, +where the global coupling $G$ drives a bifurcation into chaos. On the stable side +the exact autodiff (AD) gradient of an FC loss with respect to $G$ is correct, +checked against a finite-difference ground truth and the measured Lyapunov +exponent; past the bifurcation it flips sign and blows up. The fix is truncated +backpropagation through time (TBPTT), a single solver knob `grad_horizon` that +keeps the long forward rollout but pulls the gradient back only through a fixed +window. With it, optimizing $G$ to fit empirical FC stays on track where the full +gradient does not. + +```{python} +#| output: false +#| echo: false +# Install dependencies if running in Google Colab +try: + import google.colab + print("Running in Google Colab - installing dependencies...") + !pip install -q tvboptim + print("✓ Dependencies installed!") +except ImportError: + pass # Not in Colab, assume dependencies are available +``` + +```{python} +#| output: false +#| code-fold: true +#| code-summary: "Environment setup and imports" +#| echo: true +import os + +# Mock host devices so the G-sweep can shard across pmap on CPU. Keep this modest: +# each device runs a heavy reverse-mode AD sim and oversubscribing crashes the kernel. +N_DEVICES = 4 +os.environ.setdefault( + "XLA_FLAGS", f"--xla_force_host_platform_device_count={N_DEVICES}" +) + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import FancyBboxPatch, FancyArrowPatch +import jax +import jax.numpy as jnp +import equinox as eqx +import optax + +# 64-bit precision for reliable gradients and Lyapunov exponents +jax.config.update("jax_enable_x64", True) + +from tvboptim.utils import set_cache_path, cache +from tvboptim.data import ( + load_structural_connectivity, + load_functional_connectivity, +) +from tvboptim.execution import ParallelExecution +from tvboptim.types import GridAxis, Space, Parameter +from tvboptim.experimental.network_dynamics import Network, prepare +from tvboptim.experimental.network_dynamics.dynamics.tvb import JansenRit +from tvboptim.experimental.network_dynamics.coupling import SigmoidalJansenRit +from tvboptim.experimental.network_dynamics.graph import DenseGraph +from tvboptim.experimental.network_dynamics.noise import AdditiveNoise +from tvboptim.experimental.network_dynamics.solvers import Heun +from tvboptim.experimental.network_dynamics.analysis.lyapunov import ( + _lyapunov_spectrum_jvp, +) +from tvboptim.observations.observation import compute_fc, fc_corr, rmse +from tvboptim.observations.tvb_monitors.bold import HRFBold +from tvboptim.optim.optax import OptaxOptimizer +from tvboptim.optim.callbacks import ( + DefaultPrintCallback, + MultiCallback, + SavingCallback, +) + +set_cache_path("./tbptt") + +# Semantic palette (self-contained, no external style file). +INK = "#003754" # text / reference lines +ACCENT = "#0a5170" # accurate gradient / differentiated windows +WARN = "#AF1821" # wrong or unstable gradient / stop-gradient cuts +SHIP = "#3E8E68" # TBPTT run +ROAD = "#6b7280" # boundaries, guides +WIN_FILL = "#dce4f0" +``` + +## The Jansen-Rit Whole-Brain Model + +We use a whole-brain Jansen-Rit network on Desikan-Killiany structural +connectivity, with instantaneous sigmoidal coupling scaled by the global +coupling $G$ and additive noise on the $y_4$ population. Sweeping $G$ takes the +network from a quiescent fixed point, through synchronized oscillation, into +chaos, which is exactly the regime where long-rollout gradients become +dangerous. The empirical FC is the optimization target. + +```{python} +#| echo: true +#| output: false +# Model constants (match the chaos regime used on the poster). +A, MU, SIGMA, DT = 0.1, 0.08, 0.0316, 1.0 +G_LOW, G_HIGH, N_G = 0.0, 40.0, 80 # G sweep grid (divisible by N_DEVICES) +G_RANGE = np.linspace(G_LOW, G_HIGH, N_G) + +WARMUP_T = 20_000.0 # settle out the transient before any measurement (ms) +FC_T1 = 60_000.0 # long BOLD/FC rollout (ms): the gradient horizon at stake +BOLD_PERIOD = 720.0 # BOLD sampling period (ms) +BOLD_SKIP = 25 # BOLD samples dropped before FC +LE_SEGMENT_T = 1_000.0 # Lyapunov segment length (ms) +LE_N = 5 # Lyapunov rescaling steps + +assert N_G % N_DEVICES == 0, "N_G must be divisible by N_DEVICES for the pmap sweep" + + +def build_network(weights, region_labels, G=8.0): + """Jansen-Rit whole-brain network with instantaneous sigmoidal coupling.""" + return Network( + dynamics=JansenRit(a=A, mu=MU), + coupling={"instant": SigmoidalJansenRit(incoming_states=("y1", "y2"), G=G)}, + graph=DenseGraph(weights, region_labels=region_labels), + noise=AdditiveNoise(sigma=SIGMA, apply_to="y4"), + ) + + +coupling_key = "instant" +G_lens = lambda c: c.coupling[coupling_key].G # noqa: E731 + +weights, lengths, region_labels = load_structural_connectivity("dk_average") +weights = weights / jnp.max(weights) +fc_target = load_functional_connectivity("dk_average") + +network = build_network(weights, region_labels) + +# Warm up to a settled state; the FC sim's BOLD monitor reuses it as history. +warm_solve, warm_cfg = prepare(network, Heun(), t1=WARMUP_T, dt=DT) +warm_res = jax.block_until_ready(jax.jit(warm_solve)(warm_cfg)) +network.update_history(warm_res) + +# Two prepared solve functions: the long FC/BOLD rollout, and a short Lyapunov +# segment used to locate the chaos onset. +solve_fc, config_fc = prepare(network, Heun(), t1=FC_T1, dt=DT) +solve_le, config_le = prepare(network, Heun(), t1=LE_SEGMENT_T, dt=DT) +bold = HRFBold(period=BOLD_PERIOD, voi=0, history=warm_res) +``` + +## The Problem: the Exact Gradient Breaks at the Bifurcation + +The loss is the RMSE between the simulated FC and the empirical target. We sweep +$G$ and, at each value, compute the exact reverse-mode AD gradient $dL/dG$, a +finite-difference (FD) ground truth on the same loss, and the top Lyapunov +exponent $\lambda_{\max}$ that marks the chaos onset. + +The FD control is the honest reference here. A single secant is butterfly +dominated in the chaotic band, so we use a central difference of the +**seed-averaged** loss: at each $G$ we run several forward sims at $G \pm \delta$ +that share one noise key per seed (common random numbers), average, then form the +central difference. Forward sims carry no autodiff tape, so this stays affordable +even though the reverse-mode AD leg is expensive. + +```{python} +#| echo: true +#| output: false +FD_DELTA = 0.3 # finite central-difference step in G +N_FD_SEEDS = 8 # seeds averaged for the FD ground truth +FD_SEED_BASE = 0 +fd_keys = jax.random.split(jax.random.key(FD_SEED_BASE), N_FD_SEEDS) + + +def fc_rmse_loss(cfg): + """FC-RMSE loss for a config: long BOLD rollout -> FC -> RMSE vs target.""" + fc = compute_fc(bold(solve_fc(cfg)), skip_t=BOLD_SKIP) + return rmse(fc, fc_target) + + +def run_motivation(cfg): + """Per-G probe: [loss, AD grad, seed-averaged FD grad, FD sem, lambda_max].""" + G0 = G_lens(cfg) + + def f(g): # single-seed loss at the config's own noise key (AD value + grad) + return fc_rmse_loss(eqx.tree_at(G_lens, cfg, g)) + + val, grad_ad = jax.value_and_grad(f)(G0) + + def central_fd(key): # common random numbers across G +/- delta + c = eqx.tree_at(lambda c: c.noise.key, cfg, key) + fp = fc_rmse_loss(eqx.tree_at(G_lens, c, G0 + FD_DELTA)) + fm = fc_rmse_loss(eqx.tree_at(G_lens, c, G0 - FD_DELTA)) + return (fp - fm) / (2.0 * FD_DELTA) + + fd_seeds = jax.lax.map(central_fd, fd_keys) # sequential over seeds + fd_mean = jnp.mean(fd_seeds) + fd_sem = jnp.std(fd_seeds) / jnp.sqrt(N_FD_SEEDS) + + # Top Lyapunov exponent on a short segment at this G. The public + # `lyapunov_spectrum(network, Heun(), t, n, k=1)` is the single-network + # equivalent; here we reuse the prepared segment solver so it vmaps cleanly. + cfg_le = eqx.tree_at(G_lens, config_le, jnp.asarray(G0)) + lam = _lyapunov_spectrum_jvp(solve_le, cfg_le, t=LE_SEGMENT_T, n=LE_N, k=1)[0] + + return jnp.stack([val, grad_ad, fd_mean, fd_sem, lam]) + + +@cache("motivation") +def compute_motivation(): + _, sweep_cfg = prepare(network, Heun(), t1=FC_T1, dt=DT) + sweep_cfg.coupling[coupling_key].G = GridAxis(G_LOW, G_HIGH, N_G) + space = Space(sweep_cfg, mode="product") + results = ParallelExecution(run_motivation, space, n_pmap=N_DEVICES).run() + s = np.array([np.asarray(results[i]) for i in range(len(results))]) + return { + "g": np.asarray(G_RANGE), + "loss": s[:, 0], # FC-RMSE(G): also the optimization profile + "grad_ad": s[:, 1], + "grad_fd": s[:, 2], + "fd_sem": s[:, 3], + "lam": s[:, 4], + } + + +mot = compute_motivation() + + +def g_crit_from(lam, g): + """First G where lambda_max crosses zero (linear interp), or None.""" + pos = np.where(lam > 0)[0] + if pos.size == 0 or pos[0] == 0: + return None + i = pos[0] + return float(np.interp(0.0, [lam[i - 1], lam[i]], [g[i - 1], g[i]])) + + +G_CRIT = g_crit_from(mot["lam"], mot["g"]) +``` + +```{python} +#| label: fig-gradient-breaks +#| fig-cap: "**The exact FC-loss gradient breaks at the chaos onset.** Top: the FC-RMSE loss $L(G)$ (blue, left axis) and the top Lyapunov exponent $\\lambda_{\\max}$ (red, right axis), which crosses zero at the bifurcation into chaos (red band). The loss optimum (blue dot) sits just below that boundary. Bottom: the gradient $dL/dG$. The seed-averaged finite-difference gradient (line + band) is the ground truth; the autodiff gradient (markers) is colored by how it compares: accurate (within 1%), correct direction, wrong sign, or numerically unstable (NaN/Inf). In the fixed-point and limit-cycle regime the AD gradient matches FD; once $G$ crosses the $\\lambda_{\\max}=0$ boundary it flips sign and blows up, while the finite-difference reference stays bounded. This is the failure truncation is built to avoid." +#| code-fold: true +#| code-summary: "Show plotting code" +g = mot["g"] +loss, grad_ad, grad_fd, fd_sem, lam = ( + mot["loss"], mot["grad_ad"], mot["grad_fd"], mot["fd_sem"], mot["lam"] +) + + +def categorize(ad, fd, eps=0.01): + ad, fd = np.asarray(ad, float), np.asarray(fd, float) + ad_ok = ~(np.isnan(ad) | np.isinf(ad)) + rel = np.where(np.abs(fd) > 1e-12, np.abs(ad - fd) / np.abs(fd), np.inf) + same_sign = np.sign(np.where(ad_ok, ad, 0.0)) == np.sign(fd) + accurate = ad_ok & (rel < eps) + correct = ad_ok & same_sign & ~accurate + wrong = ad_ok & ~same_sign + unstable = ~ad_ok + return accurate, correct, wrong, unstable + + +fig, (ax_top, ax_bot) = plt.subplots( + 2, 1, figsize=(8.2, 6.6), sharex=True, + gridspec_kw={"height_ratios": [1.0, 1.2], "hspace": 0.08}, +) + + +def chaos_band(ax): + """Shade the chaotic band and mark the lambda_max=0 boundary on an axis.""" + if G_CRIT is not None: + ax.axvspan(G_CRIT, g.max(), color=WARN, alpha=0.07, zorder=0) + ax.axvline(G_CRIT, color=ROAD, ls="--", lw=1.2, zorder=0) + + +# ---- top panel: loss value + Lyapunov exponent (chaos onset) ---- +# Both series are sampled per G: scatter the measurements with a thin dashed guide, +# matching the marker style of the gradient panel below. +lam_ax = ax_top.twinx() +chaos_band(ax_top) +ax_top.plot(g, loss, "--", color=ACCENT, lw=1.0, alpha=0.5, zorder=2) +ax_top.scatter(g, loss, color=ACCENT, s=14, zorder=3) +kbest = int(np.nanargmin(loss)) +ax_top.scatter([g[kbest]], [loss[kbest]], color=ACCENT, s=90, edgecolor="black", + linewidth=1.0, zorder=4, label="Optimum") +ax_top.set_ylabel("FC RMSE $L(G)$", color=ACCENT) +ax_top.tick_params(axis="y", labelcolor=ACCENT) +ax_top.legend(loc="upper left", fontsize=9, framealpha=0.5) + +lam_s = 1000.0 * lam +lam_ax.plot(g, lam_s, "--", color=WARN, lw=1.0, alpha=0.5, zorder=2) +lam_ax.scatter(g, lam_s, color=WARN, s=14, alpha=0.85, zorder=3) +lam_ax.axhline(0.0, color=WARN, lw=0.6, ls=":", alpha=0.6) +lam_ax.set_ylabel(r"$\lambda_{max}$ (1/s)", color=WARN) +lam_ax.tick_params(axis="y", labelcolor=WARN) +if G_CRIT is not None: + lam_ax.text(G_CRIT, 0.97, r" $\lambda_{max}=0$", color=ROAD, va="top", + ha="left", transform=lam_ax.get_xaxis_transform(), fontsize=10) + +# ---- bottom panel: gradient vs FD ground truth ---- +# Adaptive symlog scale tied to the FD magnitude, so the physical gradient fills +# the frame and the chaotic-band AD blowups clip to the edge instead of stretching +# the axis over empty decades. +fd_abs = np.abs(grad_fd[np.isfinite(grad_fd)]) +cap = float(6.0 * np.nanmax(fd_abs)) +lin = float(max(0.3 * np.nanmedian(fd_abs), 1e-9)) + +chaos_band(ax_bot) +ax_bot.fill_between(g, grad_fd - fd_sem, grad_fd + fd_sem, color=INK, alpha=0.18, lw=0) +ax_bot.plot(g, grad_fd, "x--", color=INK, alpha=0.85, + label="Finite difference (seed-averaged, ground truth)") + +acc, cor, wro, uns = categorize(grad_ad, grad_fd) +ad_disp = np.clip(np.asarray(grad_ad, float), -cap * 0.96, cap * 0.96) +for mask, color, marker, label in [ + (acc, ACCENT, "o", "AD accurate"), + (cor, "#6f9bd1", "o", "AD correct direction"), + (wro, WARN, "s", "AD wrong sign"), + (uns, WARN, "X", "AD unstable (NaN/Inf)"), +]: + if np.any(mask): + y = np.where(uns, 0.0, ad_disp)[mask] if label.endswith("(NaN/Inf)") else ad_disp[mask] + ax_bot.scatter(g[mask], y, color=color, marker=marker, s=42, + edgecolor="black", linewidth=0.5, zorder=3, label=label) + +ax_bot.set_yscale("symlog", linthresh=lin, linscale=0.8) +ax_bot.set_ylim(-cap, cap) +ax_bot.axhline(0.0, color=ROAD, lw=0.8, alpha=0.5) +ax_bot.set_xlabel(r"global coupling $G$") +ax_bot.set_ylabel(r"$dL/dG$") +ax_bot.legend(loc="lower left", fontsize=8, framealpha=0.5) +plt.tight_layout() +``` + +## The Fix: Truncating the Gradient Horizon + +Truncated backpropagation through time keeps the long forward rollout but pulls +the gradient back only through a fixed-length window. In `tvboptim` it is a +single solver knob, `grad_horizon`, on the native solvers. + +The knob is easy to mis-model. The whole `n_steps` rollout runs as one continuous +forward trajectory, then is tiled into windows of $W =$ +`grad_horizon` steps. The loss depends on every window, so **every window is +differentiated**; what truncation does is sever the carried *state* gradient at +each window boundary with `stop_gradient`, so credit cannot flow across a +boundary and the $\exp(T\,\Lambda)$ blow-up is cut off at $W$. The parameter +gradient survives the severing because parameters like $G$ are *closed over* by +the scan body, not threaded through the state carry. The total is the sum of each +window's local contribution: + +$$ +\frac{dL}{dG} \;\approx\; \sum_{k=0}^{n_\text{windows}-1} + \left.\frac{\partial L_k}{\partial G}\right|_{s_k\ \text{detached}} . +$$ + +The window must be shorter than the gradient's memory horizon +$\tau_\lambda \sim 1/\lambda_{\max}$ to tame the chaotic amplification, but long +enough to keep the slow structure the loss actually depends on. + +```{python} +#| label: fig-tbptt-schematic +#| fig-cap: "**How `grad_horizon` tiles the rollout.** One continuous forward pass (top strip, color = time) is split into windows of $W$ steps. Every window is differentiated and contributes its local $\\partial L_k/\\partial G$ (blue), but the carried state gradient is cut at every boundary (red dashed). The shared parameter $G$ is closed over by all windows, so its gradient is the sum of the per-window contributions. The window length is kept below the chaotic memory horizon, $W < \\tau_\\lambda$." +#| code-fold: true +#| code-summary: "Show diagram code" +N_WIN, W = 5, 1.0 +BOX_Y0, BOX_H = 0.0, 0.95 +TOP_Y0, TOP_H = 1.55, 0.30 + +fig, ax = plt.subplots(figsize=(8.4, 2.7)) + +# top: one continuous forward rollout, cividis encodes the time axis +ax.imshow(np.linspace(0, 1, 256)[None, :], extent=(0, N_WIN * W, TOP_Y0, TOP_Y0 + TOP_H), + aspect="auto", cmap="cividis", zorder=1) +ax.add_patch(FancyBboxPatch((0, TOP_Y0), N_WIN * W, TOP_H, + boxstyle="round,pad=0,rounding_size=0.04", fill=False, + edgecolor=INK, lw=1.4, zorder=2)) +ax.text(N_WIN * W / 2, TOP_Y0 + TOP_H + 0.14, + "one continuous forward rollout (long simulation for slow statistics: FC, FCD)", + ha="center", va="bottom", fontsize=12, color=INK) +ax.annotate("", xy=(N_WIN * W - 0.35, TOP_Y0 + TOP_H / 2), xytext=(1.55, TOP_Y0 + TOP_H / 2), + arrowprops=dict(arrowstyle="-|>", color="white", lw=2.2)) +ax.text(0.9, TOP_Y0 + TOP_H / 2, "time", ha="center", va="center", + fontsize=11, color="white", style="italic") + +# windows: each differentiated, local gradient inside, stop-gradient cut at boundaries +for k in range(N_WIN): + x0 = k * W + ax.add_patch(FancyBboxPatch((x0 + 0.03, BOX_Y0), W - 0.06, BOX_H, + boxstyle="round,pad=0,rounding_size=0.05", facecolor=WIN_FILL, + edgecolor=INK, lw=1.6, zorder=3)) + ax.text(x0 + W / 2, BOX_Y0 + BOX_H * 0.66, f"window {k}", ha="center", + va="center", fontsize=12, color=INK) + ax.text(x0 + W / 2, BOX_Y0 + BOX_H * 0.28, r"$\partial L_{%d}/\partial G$" % k, + ha="center", va="center", fontsize=11, color=ACCENT) + ax.add_patch(FancyArrowPatch((x0 + W / 2, BOX_Y0 - 0.04), (x0 + W / 2, BOX_Y0 - 0.30), + arrowstyle="-|>", mutation_scale=16, color=ACCENT, lw=2.0, zorder=4)) + if k > 0: + ax.plot([x0, x0], [BOX_Y0 - 0.02, BOX_Y0 + BOX_H + 0.22], color=WARN, + lw=2.4, ls=(0, (4, 3)), zorder=5) + +ax.text(N_WIN * W / 2, BOX_Y0 + BOX_H + 0.30, "stop-gradient on carried state", + ha="center", va="bottom", fontsize=11, color=WARN, weight="bold") +ax.annotate("", xy=(W - 0.03, BOX_Y0 + BOX_H + 0.12), xytext=(0.03, BOX_Y0 + BOX_H + 0.12), + arrowprops=dict(arrowstyle="<|-|>", color=INK, lw=1.4)) +ax.text(W / 2, BOX_Y0 + BOX_H + 0.16, r"$W < \tau_\lambda$", ha="center", + va="bottom", fontsize=12, color=INK) + +# bottom: shared parameter, gradient is the sum over windows +sum_top, box_h, box_w = BOX_Y0 - 0.36, 0.46, 4.8 +ax.add_patch(FancyBboxPatch((N_WIN * W / 2 - box_w / 2, sum_top - box_h), box_w, box_h, + boxstyle="round,pad=0.02,rounding_size=0.08", facecolor=INK, + edgecolor="none", zorder=4)) +ax.text(N_WIN * W / 2, sum_top - box_h / 2, + r"shared $G$: $dL/dG \;=\; \sum\, \partial L_k/\partial G$", + ha="center", va="center", fontsize=11, color="white", zorder=5) + +ax.set_xlim(-0.45, N_WIN * W + 0.45) +ax.set_ylim(sum_top - box_h - 0.15, TOP_Y0 + TOP_H + 0.5) +ax.axis("off") +plt.tight_layout() +``` + +The truncation window is purely a solver setting: `stop_gradient` is the identity +on the forward pass, so the simulated trajectory, the FC and the loss *value* are +bit-identical to the untruncated run for every window. Only the backward pass +changes. + +## The Payoff: Optimizing $G$ With and Without Truncation + +Now we fit $G$ to the empirical FC by minimizing the FC-RMSE, starting from a +sub-critical $G$ and taking the same optimizer and learning rate in two runs that +differ only in how the gradient is obtained: + +- **Full AD** runs forward-mode autodiff through the untruncated solver. For a + single scalar parameter, forward-mode AD equals full reverse-mode BPTT (same + $dL/dG$), so this trajectory is the exact, untruncated gradient. +- **TBPTT** replaces the backward pass with a windowed solver + `Heun(grad_horizon=W)`, assigning credit only over the last $W$ steps. + +```{python} +#| echo: true +#| output: false +G_START = 5.0 # sub-critical start +LR = 1.0 +MAX_STEPS = 100 +PRINT_EVERY = 25 +W_TBPTT = 100 # gradient horizon for the truncated run (steps) + + +def make_rmse_loss(solve): + """FC-RMSE loss through a given solve function (full or windowed).""" + def loss(state): + fc = compute_fc(bold(solve(state)), skip_t=BOLD_SKIP) + err = rmse(fc, fc_target) + return err, {"rmse": err, "G": jnp.asarray(G_lens(state))} + return loss + + +def save_traj(i, diff_state, static_state, fitting_data, aux, loss_value, grads): + return {"step": int(i), "G": float(aux["G"]), "rmse": float(aux["rmse"])} + + +def run_optim(loss_fn, cfg, mode, label): + """One optimization from G_START; returns (steps, G, rmse) arrays.""" + cfg.coupling[coupling_key].G = Parameter(jnp.asarray(G_START)) + cb = MultiCallback([ + DefaultPrintCallback(every=PRINT_EVERY), + SavingCallback(key="traj", save_fun=save_traj), + ]) + opt = OptaxOptimizer(loss_fn, optax.adam(LR), callback=cb, has_aux=True) + print(f"\n=== {label} ===") + _, data = opt.run(cfg, max_steps=MAX_STEPS, mode=mode) + rows = list(data["traj"]["save"]) + return (np.array([r["step"] for r in rows]), + np.array([r["G"] for r in rows]), + np.array([r["rmse"] for r in rows])) + + +@cache("optim_runs") +def compute_optim(): + # Full AD (forward-mode == full BPTT for scalar G), untruncated solver. + _, cfg_ad = prepare(network, Heun(), t1=FC_T1, dt=DT) + ad = run_optim(make_rmse_loss(solve_fc), cfg_ad, "fwd", "Full AD (no truncation)") + + # TBPTT: identical forward dynamics, backward pass windowed to W*dt. + solve_w, cfg_w = prepare(network, Heun(grad_horizon=W_TBPTT), t1=FC_T1, dt=DT) + tb = run_optim(make_rmse_loss(solve_w), cfg_w, "rev", f"TBPTT W={W_TBPTT}") + return {"ad": ad, "tb": tb} + + +runs = compute_optim() + + +def _last_finite(arr): + a = np.asarray(arr, float) + fin = np.flatnonzero(np.isfinite(a)) + return float(a[fin[-1]]) if fin.size else float("nan") + + +@cache("final_fc") +def compute_final_fc(): + """Simulated FC at each run's optimized G (last finite step), vs the target.""" + def fc_at(G): + cfg = eqx.tree_at(G_lens, config_fc, jnp.asarray(float(G))) + fc = compute_fc(bold(solve_fc(cfg)), skip_t=BOLD_SKIP) + return np.asarray(fc), float(rmse(fc, fc_target)), float(fc_corr(fc, fc_target)) + + out = {"fc_target": np.asarray(fc_target)} + for key in ("ad", "tb"): + G = _last_finite(runs[key][1]) + fc, err, r = fc_at(G) + out[key] = {"G": G, "fc": fc, "rmse": err, "r": r} + return out + + +final_fc = compute_final_fc() +``` + +```{python} +#| label: fig-optim-trajectories +#| fig-cap: "**Optimizing $G$ to fit FC, with and without truncation.** Top: each trajectory is one run; the y-axis is $G$ at each optimization step, points colored by the FC-RMSE they reached (lower is better). The dashed line is the $\\lambda_{max}=0$ boundary and the red band above it is the chaotic regime. The grey curve on the left is the $RMSE(G)$ profile from the sweep above, bulging right at the optimum. The full-AD run is driven by the unstable gradient into the chaotic band; the TBPTT ($W=100$) run descends on a stable, short-horizon gradient and settles near the RMSE optimum on the stable side. Bottom: the simulated FC at each run's optimized $G$ beside the empirical target. The TBPTT fit reproduces the target structure; the full-AD run, stranded in the chaotic band, does not." +#| code-fold: true +#| code-summary: "Show plotting code" +ad_run, tb_run = runs["ad"], runs["tb"] + +fig = plt.figure(figsize=(9.0, 8.0), layout="constrained") +gs = fig.add_gridspec(2, 1, height_ratios=[1.5, 1.0]) +ax = fig.add_subplot(gs[0]) +gs_fc = gs[1].subgridspec(1, 3, wspace=0.08) +fc_axes = [fig.add_subplot(gs_fc[0, j]) for j in range(3)] + +# ---- top panel: step-vs-G trajectories ---- +g_top = max(np.nanmax(ad_run[1]), np.nanmax(tb_run[1])) + +# chaos boundary + band +if G_CRIT is not None: + ax.axhline(G_CRIT, color=ROAD, ls="--", lw=1.5, zorder=0) + ax.axhspan(G_CRIT, g_top * 1.05, color=WARN, alpha=0.07, zorder=0) + ax.text(0.99, G_CRIT, r"$\lambda_{max}=0$ ", color=ROAD, va="bottom", ha="right", + transform=ax.get_yaxis_transform(), fontsize=10) + +# sideways RMSE(G) reference profile (bulges right at the minimum = optimum) +PROFILE_X0, PROFILE_W = 0.0, 0.18 * MAX_STEPS +err_prof = mot["loss"] +emin, emax = float(np.nanmin(err_prof)), float(np.nanmax(err_prof)) +if emax > emin: + xprof = PROFILE_X0 + PROFILE_W * (emax - err_prof) / (emax - emin) + ax.fill_betweenx(mot["g"], PROFILE_X0, xprof, color=ROAD, alpha=0.12, lw=0, zorder=0) + ax.plot(xprof, mot["g"], color=ROAD, lw=1.2, alpha=0.6, zorder=1) + kbest = int(np.nanargmin(err_prof)) + ax.annotate(r"RMSE$(G)$", xy=(xprof[kbest], mot["g"][kbest]), xytext=(6, 0), + textcoords="offset points", color=ROAD, fontsize=11, va="center") + +# trajectories, colored by RMSE on a shared scale +rmse_all = np.concatenate([ad_run[2], tb_run[2]]) +vmin, vmax = float(np.nanmin(rmse_all)), float(np.nanmax(rmse_all)) +runs_plot = [ + (ad_run, "s", WARN, "Full AD (no truncation)"), + (tb_run, "^", SHIP, f"TBPTT ($W={W_TBPTT}$)"), +] +sc = None +for (steps, G, err), marker, line_c, _ in runs_plot: + ax.plot(steps, G, color=line_c, lw=1.0, alpha=0.4, zorder=1) + sc = ax.scatter(steps, G, c=err, cmap="cividis_r", vmin=vmin, vmax=vmax, + marker=marker, s=42, edgecolor="black", linewidth=0.4, zorder=3) + +ax.set_xlabel("optimization step") +ax.set_ylabel(r"global coupling $G$") +fig.colorbar(sc, ax=ax, label="FC RMSE") +handles = [Line2D([0], [0], color=c, marker=m, lw=1.0, mec="black", label=lab) + for _, m, c, lab in runs_plot] +ax.legend(handles=handles, loc="best", fontsize=10) + +# ---- bottom panel: simulated FC at each run's optimum vs the empirical target ---- +panels = [ + ("Empirical target", final_fc["fc_target"], None), + (rf"TBPTT ($W={W_TBPTT}$)", final_fc["tb"]["fc"], final_fc["tb"]), + ("Full AD", final_fc["ad"]["fc"], final_fc["ad"]), +] +mats = np.array([p[1] for p in panels]) +offdiag = ~np.eye(mats.shape[1], dtype=bool) +vmax_fc = float(np.nanpercentile(mats[:, offdiag], 95)) +im = None +for fa, (title, mat, info) in zip(fc_axes, panels): + m = np.array(mat, float) + np.fill_diagonal(m, np.nan) + im = fa.imshow(m, cmap="cividis", vmin=0.0, vmax=vmax_fc, aspect="equal", interpolation="none") + fa.set_xticks([]) + fa.set_yticks([]) + sub = title if info is None else ( + f"{title}\n$G={info['G']:.1f}$, RMSE$={info['rmse']:.3f}$, " + rf"$r_{{FC}}={info['r']:.3f}$" + ) + fa.set_title(sub, fontsize=10, color=INK) +fig.colorbar(im, ax=fc_axes, label="FC", shrink=0.8) +``` + +## Interpretation and Practical Recipe + +- **The exact gradient is correct until the dynamics turn chaotic.** Below the + $\lambda_{\max}=0$ boundary the AD gradient agrees with the finite-difference + ground truth. Above it the AD gradient flips sign and diverges, so an optimizer + driven by it is pushed into the chaotic regime instead of toward the fit. +- **Truncation restores a usable gradient.** Severing the carry gradient at the + window boundary caps the backward horizon at $W\,dt$, so the runaway + amplification never accumulates. The windowed run descends on a stable gradient + and settles near the FC optimum on the stable side. +- **The window is a real hyperparameter.** A too-short window is blind to slow + FC structure and biases toward fast timescales; a too-long window lets the + chaotic amplification back in. Choose $W$ from the slowest timescale you need + credit for, not from the simulation length. + +::: {.callout-note} +## Forward is unchanged + +`grad_horizon` only changes the backward pass. The forward trajectory, and hence +the FC and the loss *value*, are bit-identical to the untruncated run for every +window. The same long simulation can serve a long-horizon statistic while the +gradient is taken over a short, stable window. +::: + +`grad_horizon` sets the gradient horizon, not the cost. Bounding runtime and +memory is the job of `block_size` (gradient checkpointing), an orthogonal knob +that rematerializes activations within each window and nests with `grad_horizon`. +The same `block_size` seam can also stream the FC so memory scales with the block +rather than `n_steps`. See the [RWW tutorial](RWW.qmd) for the full FC/BOLD +pipeline. + +## References + +Truncated backpropagation through time: + +- Werbos, P. J. (1990). Backpropagation through time: what it does and how to do + it. *Proceedings of the IEEE*, 78(10), 1550-1560. +- Williams, R. J., & Peng, J. (1990). An efficient gradient-based algorithm for + on-line training of recurrent network trajectories. *Neural Computation*, 2(4), + 490-501. The windowed truncated-BPTT scheme used here. +- Tallec, C., & Ollivier, Y. (2017). Unbiasing truncated backpropagation through + time. *arXiv:1705.08209*. On the bias of a too-short window. + +Why the gradient horizon is finite (the $\exp(T\,\Lambda)$ growth): + +- Bengio, Y., Simard, P., & Frasconi, P. (1994). Learning long-term dependencies + with gradient descent is difficult. *IEEE Transactions on Neural Networks*, + 5(2), 157-166. +- Pascanu, R., Mikolov, T., & Bengio, Y. (2013). On the difficulty of training + recurrent neural networks. *ICML*, PMLR 28(3), 1310-1318. Ties the gradient + growth to the recurrent Jacobian's spectral radius (the $\Lambda$ above). + +Why an FC (covariance) loss's horizon is the system's correlation time: + +- Kubo, R. (1966). The fluctuation-dissipation theorem. *Reports on Progress in + Physics*, 29(1), 255-284. diff --git a/docs/workflows/cache/jr/optimize.pkl b/docs/workflows/cache/jr/optimize.pkl index a832d9d..10ccbed 100644 Binary files a/docs/workflows/cache/jr/optimize.pkl and b/docs/workflows/cache/jr/optimize.pkl differ diff --git a/docs/workflows/cache/rww/explore.pkl b/docs/workflows/cache/rww/explore.pkl index 5454456..92f1e55 100644 Binary files a/docs/workflows/cache/rww/explore.pkl and b/docs/workflows/cache/rww/explore.pkl differ diff --git a/docs/workflows/cache/rww/optimize.pkl b/docs/workflows/cache/rww/optimize.pkl index 0657395..3f980d1 100644 Binary files a/docs/workflows/cache/rww/optimize.pkl and b/docs/workflows/cache/rww/optimize.pkl differ diff --git a/docs/workflows/cache/rww/optimize_het.pkl b/docs/workflows/cache/rww/optimize_het.pkl index 8a6329f..1d719f9 100644 Binary files a/docs/workflows/cache/rww/optimize_het.pkl and b/docs/workflows/cache/rww/optimize_het.pkl differ diff --git a/docs/workflows/cache/tbptt/final_fc.pkl b/docs/workflows/cache/tbptt/final_fc.pkl new file mode 100644 index 0000000..50147ef Binary files /dev/null and b/docs/workflows/cache/tbptt/final_fc.pkl differ diff --git a/docs/workflows/cache/tbptt/motivation.pkl b/docs/workflows/cache/tbptt/motivation.pkl new file mode 100644 index 0000000..a917d6d Binary files /dev/null and b/docs/workflows/cache/tbptt/motivation.pkl differ diff --git a/docs/workflows/cache/tbptt/optim_runs.pkl b/docs/workflows/cache/tbptt/optim_runs.pkl new file mode 100644 index 0000000..9901fd1 Binary files /dev/null and b/docs/workflows/cache/tbptt/optim_runs.pkl differ diff --git a/docs/workflows/cache/tbptt/tbptt_sweep.pkl b/docs/workflows/cache/tbptt/tbptt_sweep.pkl new file mode 100644 index 0000000..a82edbb Binary files /dev/null and b/docs/workflows/cache/tbptt/tbptt_sweep.pkl differ diff --git a/pyproject.toml b/pyproject.toml index c7fdd98..1c34e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "tvboptim" -version = "0.3.0" +version = "0.3.1" description = "Optimization tools for The Virtual Brain" readme = "README.md" authors = [ @@ -100,6 +100,14 @@ docs = [ "griffe<2.0.0" ] +[tool.uv] +# tvb-library depends on numba without an upper or lower bound, so a fresh +# resolution can backtrack numba down to 0.53.1 (-> llvmlite 0.36.0), which +# only builds on Python <3.10 and breaks CI on the >=3.11 matrix. Floor numba +# at a version that ships wheels for our supported Pythons. Constraints only +# bound transitively-pulled packages; they do not add a direct dependency. +constraint-dependencies = ["numba>=0.61", "llvmlite>=0.44"] + [tool.ruff] line-length = 88 target-version = "py311" diff --git a/src/tvboptim/analysis/identifiability.py b/src/tvboptim/analysis/identifiability.py index ced70e5..ecd6544 100644 --- a/src/tvboptim/analysis/identifiability.py +++ b/src/tvboptim/analysis/identifiability.py @@ -235,8 +235,7 @@ def summary(self) -> str: """Human-readable one-screen report.""" label = _KIND_LABELS.get(self.kind, self.kind) lines = [ - f"Identifiability analysis -- {label} " - f"({self.n_params} parameters)", + f"Identifiability analysis -- {label} ({self.n_params} parameters)", "-" * 60, ] if self.gradient_norm is not None: @@ -423,9 +422,7 @@ def eigendecompose_curvature( downstream in ``summary()`` / ``__repr__``. """ if kind not in ("hessian", "fisher"): - raise ValueError( - f"`kind` must be 'hessian' or 'fisher', got {kind!r}." - ) + raise ValueError(f"`kind` must be 'hessian' or 'fisher', got {kind!r}.") matrix = jnp.asarray(matrix) matrix = 0.5 * (matrix + matrix.T) # symmetrize against floating-point asymmetry eigenvalues, eigenvectors = jnp.linalg.eigh(matrix) diff --git a/src/tvboptim/experimental/network_dynamics/analysis/__init__.py b/src/tvboptim/experimental/network_dynamics/analysis/__init__.py index 2287e72..fb83ead 100644 --- a/src/tvboptim/experimental/network_dynamics/analysis/__init__.py +++ b/src/tvboptim/experimental/network_dynamics/analysis/__init__.py @@ -1,3 +1,9 @@ +from .adiabatic_scan import AdiabaticScanResult, adiabatic_scan from .lyapunov import lyapunov, lyapunov_spectrum -__all__ = ["lyapunov", "lyapunov_spectrum"] +__all__ = [ + "lyapunov", + "lyapunov_spectrum", + "adiabatic_scan", + "AdiabaticScanResult", +] diff --git a/src/tvboptim/experimental/network_dynamics/analysis/adiabatic_scan.py b/src/tvboptim/experimental/network_dynamics/analysis/adiabatic_scan.py new file mode 100644 index 0000000..2a67f62 --- /dev/null +++ b/src/tvboptim/experimental/network_dynamics/analysis/adiabatic_scan.py @@ -0,0 +1,183 @@ +"""Adiabatic parameter scan: a network bifurcation diagram. + +Slowly ramp one parameter from ``low`` to ``high`` (and optionally back down to +catch hysteresis), carrying the settled state forward between steps, and record +summary statistics of an observed network signal at each value. This traces a +bifurcation-diagram-like picture of how the network's activity changes with the +swept parameter. + +The swept parameter is addressed with a lens ``accessor`` applied through +``eqx.tree_at`` (e.g. ``lambda c: c.coupling.instant.G``), so any nested config +field can be scanned without the function knowing its name. +""" + +from dataclasses import dataclass +from typing import Callable, Dict + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np + +from .. import prepare +from ..core.bunch import Bunch +from ..solvers import Heun + + +@dataclass +class AdiabaticScanResult: + """Result of an :func:`adiabatic_scan`. + + Attributes + ---------- + p : jax.Array + The swept parameter values, in scan order. Length ``2*n`` when + ``bothways`` (up then down), else ``n``. + n_up : int + Number of values in the ascending branch. ``p[:n_up]`` is the up-branch + and ``p[n_up:]`` the down-branch. + stats : Bunch of str -> jax.Array + One array per recorded statistic, stacked along the scan axis and + reachable by attribute (``stats.mean``) or key (``stats["mean"]``). + Shape is ``[len(p)]`` for scalar reducers and ``[len(p), ...]`` for + vector-valued reducers (e.g. ``[len(p), n_nodes]`` for a per-node + statistic). + """ + + p: jax.Array + n_up: int + stats: Bunch + + +def _default_observe(result): + """Observe the first variable across all nodes -> [n_time, n_nodes].""" + return result.ys[:, 0, :] + + +def _network_mean(arr): + """Mean over time per node, then averaged across nodes.""" + return jnp.mean(arr, axis=0).mean() + + +def _network_min(arr): + """Mean over time per node, then the minimum across nodes.""" + return jnp.mean(arr, axis=0).min() + + +def _network_max(arr): + """Mean over time per node, then the maximum across nodes.""" + return jnp.mean(arr, axis=0).max() + + +_DEFAULT_STATISTICS = { + "mean": _network_mean, + "min": _network_min, + "max": _network_max, +} + + +def adiabatic_scan( + network, + solver=None, + *, + accessor: Callable, + low: float, + high: float, + n: int, + t: float = 2000.0, + skip: float = 1000.0, + dt: float = 1.0, + t0: float = 0.0, + bothways: bool = True, + observe: Callable = None, + statistics: Dict[str, Callable] = None, +) -> AdiabaticScanResult: + """Ramp one parameter and record network statistics (bifurcation diagram). + + Parameters + ---------- + network : Network + solver : solver instance, optional (default: Heun()) + accessor : callable + Lens onto the swept leaf, used as ``eqx.tree_at(accessor, config, value)``. + Example: ``lambda c: c.coupling.instant.G``. + low, high : float + Bounds of the swept parameter. + n : int + Number of values per branch. + t : float + Simulation duration per step in ms. + skip : float + Transient duration in ms discarded before computing statistics. + dt : float + Integration timestep in ms. + t0 : float + Simulation start time. + bothways : bool + If True, scan up then back down to expose hysteresis. + observe : callable, optional + ``result -> [n_time, n_nodes]`` signal to summarise. Defaults to the + first variable across all nodes. + statistics : dict of str -> callable, optional + Maps a name to a reducer ``[n_time, n_nodes] -> scalar or array``. + Vector-valued reducers (e.g. a per-node ``[n_nodes]`` statistic) are + supported as long as the output shape is the same at every scan point. + Defaults to mean/min/max of the per-node temporal mean across the + network. + + Returns + ------- + AdiabaticScanResult + + Notes + ----- + The settled state is carried forward between steps (the slow, adiabatic + ramp). For networks with delayed coupling the delay history buffer is not + carried, so this is only exact for instantaneous (non-delayed) coupling. + """ + if solver is None: + solver = Heun() + if observe is None: + observe = _default_observe + if statistics is None: + statistics = _DEFAULT_STATISTICS + + solve_fn, config = prepare(network, solver, t0=t0, t1=t0 + t, dt=dt) + + p_up = jnp.linspace(low, high, n) + if bothways: + p = jnp.concatenate([p_up, p_up[::-1]]) + else: + p = p_up + + n_states = config.initial_state.dynamics.shape[0] + init_state = config.initial_state.dynamics + + # The save grid is deterministic in (t0, t1, dt), so the post-transient + # window is the same at every step. Resolve it once to static integer + # indices (a host-side computation) so the scan body has no data-dependent + # shapes and stays jittable. + probe_ts = np.asarray(solve_fn(config).ts) + keep = jnp.asarray(np.flatnonzero(probe_ts > (t0 + skip))) + + def step(state, value): + cfg = eqx.tree_at(accessor, config, value) + cfg = eqx.tree_at(lambda c: c.initial_state.dynamics, cfg, state) + result = solve_fn(cfg) + + # statistics over the post-transient window. Reducer outputs stay on the + # JAX side; lax.scan stacks them along the scan axis, so array-valued + # reducers like a per-node median ([n_nodes]) are supported. + arr = observe(result)[keep] + outs = {name: fn(arr) for name, fn in statistics.items()} + + # carry the settled state forward (the slow, adiabatic ramp) + new_state = result.ys[-1][:n_states] + return new_state, outs + + # The whole sweep is one jitted scan: a single XLA compilation, the carry + # expressed natively, and the function is vmap-able over a batch of `p` + # arrays to explore several ranges in parallel. + _, stacked = jax.jit(lambda s, ps: jax.lax.scan(step, s, ps))(init_state, p) + + return AdiabaticScanResult(p=p, n_up=n, stats=Bunch(stacked)) diff --git a/src/tvboptim/experimental/network_dynamics/analysis/lyapunov.py b/src/tvboptim/experimental/network_dynamics/analysis/lyapunov.py index c8c3977..98a0cd5 100644 --- a/src/tvboptim/experimental/network_dynamics/analysis/lyapunov.py +++ b/src/tvboptim/experimental/network_dynamics/analysis/lyapunov.py @@ -30,10 +30,15 @@ def lyapunov(network, solver=None, t=1000.0, n=10, d0=1e-9, dt=0.1, t0=0.0): Notes ----- - For networks with delayed coupling the delay history is held fixed - across segments (good approximation when t >> max_delay). Run a - warmup via network.update_history(result) before calling this - function to initialise the delay buffer from a settled trajectory. + For delayed coupling the history buffer is held fixed across segments + (each rescaling step re-seeds only the point state), so the result is a + good approximation only when t >> max_delay, not exact. The two-trajectory + method is the natural route to a delay-correct MLE -- each trajectory + evolves its own history, so no explicit history-space tangents are needed -- + but that needs carrying the full final state (point + history) across + segments, or running unsegmented. Run a warmup via + network.update_history(result) first to initialise the delay buffer from a + settled trajectory. """ if solver is None: solver = Heun() @@ -95,6 +100,15 @@ def lyapunov_spectrum( Returns ------- jnp.ndarray — top k Lyapunov exponents sorted descending (1/ms) + + Notes + ----- + Exact for instantaneous (non-delayed) coupling only. Both modes propagate a + *point* state (perturb/linearize only ``initial_state.dynamics``) and reset + the delay history buffer each rescaling step, so for delayed coupling the + history is held fixed across segments: a good approximation only when + ``t >> max_delay``, not the true DDE spectrum (which would need tangents + spanning the augmented history-buffer state). """ if solver is None: solver = Heun() diff --git a/src/tvboptim/experimental/network_dynamics/external_input/base.py b/src/tvboptim/experimental/network_dynamics/external_input/base.py index ea619aa..23de4b5 100644 --- a/src/tvboptim/experimental/network_dynamics/external_input/base.py +++ b/src/tvboptim/experimental/network_dynamics/external_input/base.py @@ -211,7 +211,9 @@ def plot( show_legend = n_nodes <= 8 for d in range(n_dims): for i in range(n_nodes): - ax[d].plot(ts, signals[:, d, i], label=f"node {i}" if show_legend else None) + ax[d].plot( + ts, signals[:, d, i], label=f"node {i}" if show_legend else None + ) ax[d].set_ylabel(f"input[{d}]" if n_dims > 1 else "input") if show_legend and n_nodes > 1: ax[d].legend(loc="best", fontsize="small") diff --git a/src/tvboptim/experimental/network_dynamics/solve.py b/src/tvboptim/experimental/network_dynamics/solve.py index fb13aef..fea5350 100644 --- a/src/tvboptim/experimental/network_dynamics/solve.py +++ b/src/tvboptim/experimental/network_dynamics/solve.py @@ -5,6 +5,7 @@ management, and returns a pure function for execution. """ +import warnings from typing import Callable, Tuple import diffrax @@ -32,29 +33,24 @@ def _snapshot(tree): return jax.tree.map(lambda x: x, tree) -def _checkpointed_scan(op, state0, scan_inputs, n_steps, block_size): - """Run ``op`` over ``scan_inputs`` as an outer-checkpointed nested scan. +def _blocked_scan(runner, state0, scan_inputs, n_steps, block_size): + """Split the leading axis into ``(n_blocks, block_size)`` plus a tail, scan + ``runner`` over the main blocks, run ``runner`` once on the tail, and stitch + the outputs back to leading shape ``(n_steps, ...)``. - The leading axis of every leaf in ``scan_inputs`` is split into - ``(n_blocks, block_size)``; an outer ``jax.lax.scan`` runs over blocks - with each block wrapped in ``jax.checkpoint``, and an inner - ``jax.lax.scan`` runs the ``block_size`` steps inside a block. When - ``n_steps`` is not a multiple of ``block_size`` the remainder runs through - a plain (uncheckpointed) tail scan; the tail length is fixed at trace - time and is at most ``block_size - 1``. + ``runner(state, block_inputs, block_len) -> (state, outs)`` is the only + per-block behaviour that varies between callers (checkpointed inner scan, + truncated window, per-block streaming/reduce); this helper owns the + reshape/tail/stitch skeleton they all share. ``block_len`` is a static + Python int (``block_size`` for the main blocks, the remainder for the tail). - Output of the scanned computation is stitched back to leading shape - ``(n_steps, ...)`` so the result is indistinguishable from a single - ``jax.lax.scan(op, state0, scan_inputs)`` call to downstream code. + When ``runner`` emits ``None`` outputs (the reduce/fold case) the ``None`` + threads through ``jax.tree.map`` and the concatenate untouched, so the + accumulator-in-carry path needs no special handling here. """ n_blocks = n_steps // block_size remainder = n_steps - n_blocks * block_size - def inner(state, block_inputs): - return jax.lax.scan(op, state, block_inputs) - - ckpt_inner = jax.checkpoint(inner) - if n_blocks > 0: main = jax.tree.map( lambda x: x[: n_blocks * block_size].reshape( @@ -62,7 +58,9 @@ def inner(state, block_inputs): ), scan_inputs, ) - state_mid, outs_main = jax.lax.scan(ckpt_inner, state0, main) + state_mid, outs_main = jax.lax.scan( + lambda s, b: runner(s, b, block_size), state0, main + ) outs_main_flat = jax.tree.map( lambda x: x.reshape((n_blocks * block_size,) + x.shape[2:]), outs_main, @@ -75,14 +73,7 @@ def inner(state, block_inputs): return state_mid, outs_main_flat tail = jax.tree.map(lambda x: x[n_blocks * block_size :], scan_inputs) - # Wrap the tail in jax.checkpoint too so its activation tape is not - # held live during the outer backward through the main blocks. Without - # this, peak memory for non-divisor K becomes - # ``K · c_inner + remainder · c_unchecked`` and large remainders can - # spike above the no-checkpoint baseline. - state_final, outs_tail = jax.checkpoint( - lambda s, t: jax.lax.scan(op, s, t) - )(state_mid, tail) + state_final, outs_tail = runner(state_mid, tail, remainder) if outs_main_flat is None: return state_final, outs_tail @@ -93,6 +84,350 @@ def inner(state, block_inputs): return state_final, outs +def _block_scan(op, state0, scan_inputs, n_steps, block_size): + """Run ``op`` over ``scan_inputs`` as an outer-checkpointed nested scan. + + The leading axis of every leaf in ``scan_inputs`` is split into + ``(n_blocks, block_size)``; an outer ``jax.lax.scan`` runs over blocks + with each block wrapped in ``jax.checkpoint``, and an inner + ``jax.lax.scan`` runs the ``block_size`` steps inside a block. When + ``n_steps`` is not a multiple of ``block_size`` the remainder runs through + a checkpointed tail scan; the tail length is fixed at trace time and is at + most ``block_size - 1``. + + Output of the scanned computation is stitched back to leading shape + ``(n_steps, ...)`` so the result is indistinguishable from a single + ``jax.lax.scan(op, state0, scan_inputs)`` call to downstream code. + """ + # One checkpointed inner scan serves both the main blocks and the tail, so + # every block's activation tape is rematerialized on the backward pass + # rather than held live across the whole rollout; only block-boundary + # carries are retained. Wrapping the tail too keeps peak memory for + # non-divisor block sizes from spiking above the no-checkpoint baseline. + block = jax.checkpoint( + lambda state, block_inputs: jax.lax.scan(op, state, block_inputs) + ) + return _blocked_scan( + lambda state, block_inputs, block_len: block(state, block_inputs), + state0, + scan_inputs, + n_steps, + block_size, + ) + + +def _truncated_scan(op, state0, scan_inputs, n_steps, window_size, block_size): + """Run ``op`` over ``scan_inputs`` as a windowed scan with truncated BPTT. + + The leading axis of every leaf in ``scan_inputs`` is split into + ``(n_windows, window_size)``; an outer ``jax.lax.scan`` runs over windows + and the carry gradient is severed with ``jax.lax.stop_gradient`` at the + entry to each window, so credit is assigned only within a window + (truncated backpropagation through time). Within a window the steps run as + a plain inner ``jax.lax.scan`` when ``block_size`` is None, or as a + ``_block_scan`` over ``block_size`` blocks when it is set (the memory + granularity nests inside the gradient window). When ``n_steps`` is not a + multiple of ``window_size`` the remainder runs as a final shorter window + with the same carry severing and inner runner. + + The forward computation is identical to a single ``jax.lax.scan(op, ...)`` + (``stop_gradient`` is the identity on the forward); only the backward pass + is truncated. Output is stitched back to leading shape ``(n_steps, ...)`` so + the result is indistinguishable from the untruncated scan to downstream code. + """ + + def run_window(state, window_inputs, window_len): + # Sever cross-window credit: the window sees a detached start state, so + # gradient flows within the window (and to op's closed-over parameters) + # but not back to prior windows. Forward value is unchanged. + state = jax.lax.stop_gradient(state) + if block_size is None: + return jax.lax.scan(op, state, window_inputs) + return _block_scan(op, state, window_inputs, window_len, block_size) + + return _blocked_scan(run_window, state0, scan_inputs, n_steps, window_size) + + +def _split_voi(dynamics): + """Split variables-of-interest into state and auxiliary index arrays. + + Returns ``(state_voi_indices, aux_voi_indices, record_auxiliaries, + variable_names)``. ``variable_names`` labels axis 1 of the output + trajectory: selected states first, then selected auxiliaries, matching + the concatenation order in ``_assemble_output``. + """ + voi_indices = dynamics.get_variables_of_interest_indices() + n_states = dynamics.N_STATES + state_voi_indices = jnp.array([i for i in voi_indices if i < n_states], dtype=int) + aux_voi_indices = jnp.array( + [i - n_states for i in voi_indices if i >= n_states], dtype=int + ) + record_auxiliaries = len(aux_voi_indices) > 0 + all_variable_names = dynamics.all_variable_names + variable_names = tuple( + all_variable_names[i] for i in voi_indices if i < n_states + ) + tuple(all_variable_names[i] for i in voi_indices if i >= n_states) + return state_voi_indices, aux_voi_indices, record_auxiliaries, variable_names + + +def _materialize_noise(noise_key, injected, shape): + """Draw the full per-call noise tensor, or pass through an injected override. + + A single fused ``jax.random.normal`` draw of shape ``[n_steps, + n_noise_states, n_nodes]`` when no override is present (XLA fuses this with + the downstream scan); otherwise the caller-supplied tensor verbatim (the + NumPyro-over-increments replay path). Callers gate this on the network + actually having noise. + """ + if injected is None: + return jax.random.normal(noise_key, shape) + return injected + + +def _streaming_noise_gen(noise_key, per_node_shape, provider=None): + """Per-block streaming noise generator for the block scan. + + Returns ``noise_gen(block_idx, block_len) -> [block_len, *per_node_shape]``. + The default draws ``jax.random.normal(jax.random.fold_in(noise_key, + block_idx), ...)``, so the realization is a pure function of + ``(key, block_idx)`` and the block grain, regenerated (not stored) on the + backward pass. ``per_node_shape`` is ``(n_noise_states, n_nodes)``. A + ``provider`` (``config._internal.noise_provider``) overrides the draw for the + streaming injection workflow. + """ + if provider is not None: + return lambda block_idx, block_len: provider( + block_idx, noise_key, (block_len,) + per_node_shape + ) + return lambda block_idx, block_len: jax.random.normal( + jax.random.fold_in(noise_key, block_idx), (block_len,) + per_node_shape + ) + + +def _assemble_output( + next_state, auxiliaries, state_voi_indices, aux_voi_indices, record_auxiliaries +): + """Select one step's variables-of-interest slice of the output. + + Concatenates selected state variables and (optionally) selected + auxiliaries along axis 0, matching the ordering of ``variable_names`` + from ``_split_voi``. + """ + if len(state_voi_indices) > 0: + selected_states = next_state[state_voi_indices] + else: + selected_states = jnp.array([]).reshape(0, next_state.shape[1]) + + if record_auxiliaries and auxiliaries.size > 0: + selected_aux = auxiliaries[aux_voi_indices] + return jnp.concatenate([selected_states, selected_aux], axis=0) + return selected_states + + +def _composed_scan( + block_step, carry0, scan_inputs, n_steps, block_size, window_size=None +): + """Run a block-wise scan with a pluggable per-block ``block_step``. + + ``block_step(carry, block_inputs) -> (carry, outs)`` owns the per-block work + (inner step scan, optional noise generation, stack-or-fold of the output); + it carries whatever state it needs (e.g. ``(state, acc)`` or + ``(state, counter)``). This helper owns only the composition: blocks of + ``block_size`` via ``_blocked_scan``, and, when ``window_size`` is set + (truncated BPTT), an outer window scan that severs the whole carry with + ``jax.lax.stop_gradient`` at each window entry and nests the blocks inside. + Returns ``(final_carry, outs)``; ``outs`` is the stitched trajectory for a + stacking ``block_step`` and ``None`` for a folding one. + """ + + def run_blocks(carry, inputs, length): + return _blocked_scan( + lambda c, b, _len: block_step(c, b), carry, inputs, length, block_size + ) + + if window_size is None: + return run_blocks(carry0, scan_inputs, n_steps) + + def run_window(carry, window_inputs, window_len): + # Sever cross-window credit on the whole carry: the window contributes + # its local d(.)/d(theta) with the window-start carry detached. Forward + # value is unchanged (stop_gradient is the identity on the forward). + carry = jax.lax.stop_gradient(carry) + return run_blocks(carry, window_inputs, window_len) + + return _blocked_scan(run_window, carry0, scan_inputs, n_steps, window_size) + + +def _fold_block(op, update): + """Folding block step (carry ``(state, acc)``), presampled/no-noise path. + + Runs the block's steps through an inner ``jax.lax.scan(op, ...)`` to a + stacked ``[block_len, ...]`` output and folds it into ``acc`` at block + granularity (one batched ``update`` per block, not per step). The whole + block, including the ``update``, is ``jax.checkpoint``-wrapped so it is + rematerialized on the backward pass; only the block-boundary ``(state, acc)`` + is retained. Emits ``None`` (no trajectory stacked). + """ + + @jax.checkpoint + def step(carry, block_inputs): + state, acc = carry + state, block_out = jax.lax.scan(op, state, block_inputs) + return (state, update(acc, block_out)), None + + return step + + +def _stream_block(op, noise_gen): + """Streaming stacking block step (carry ``(state, counter)``). + + Generates this block's noise from the absolute block ordinal ``counter`` + (``noise_gen(counter, block_len)``), threads ``(time_chunk, noise)`` into the + inner step scan, and stacks the block output. ``counter`` increments per + block so the noise is a pure function of ``(key, absolute_block_idx)``, + independent of how truncation windows tile the rollout. ``jax.checkpoint`` + wraps the block so the draw is *regenerated*, not stored, on the backward + pass (the streaming memory win). + """ + + @jax.checkpoint + def step(carry, time_chunk): + state, counter = carry + noise = noise_gen(counter, time_chunk.shape[0]) + state, out = jax.lax.scan(op, state, (time_chunk, noise)) + return (state, counter + 1), out + + return step + + +def _stream_fold_block(op, update, noise_gen): + """Streaming folding block step (carry ``(state, acc, counter)``). + + Combines per-block streaming noise (as in ``_stream_block``) with the + block-level fold of ``_fold_block``: generates the block noise from + ``counter``, runs the inner step scan over ``(time_chunk, noise)``, folds the + block output into ``acc``, and increments ``counter``. ``jax.checkpoint`` + rematerializes the whole block (noise draw + inner scan + update) on the + backward pass. Emits ``None``. + """ + + @jax.checkpoint + def step(carry, time_chunk): + state, acc, counter = carry + noise = noise_gen(counter, time_chunk.shape[0]) + state, block_out = jax.lax.scan(op, state, (time_chunk, noise)) + return (state, update(acc, block_out), counter + 1), None + + return step + + +def _snap_window(window, block_size): + """Snap a truncation window to the nearest multiple of ``block_size``. + + Window boundaries must align with block boundaries so the blocks nest and + (under streaming noise) the absolute block grid is independent of the window + tiling. A non-multiple window is rounded to the nearest multiple (floored at + one block) with a warning, rather than silently changing the gradient + horizon. + """ + if window % block_size == 0: + return window + snapped = max(block_size, round(window / block_size) * block_size) + warnings.warn( + f"grad_horizon={window} is not a multiple of block_size={block_size}; " + f"snapping to {snapped} so window boundaries align with block " + "boundaries.", + stacklevel=2, + ) + return snapped + + +def _reduce_fold(reduce, variable_names, n_nodes, n_steps): + """Build the ``(acc0, update)`` fold pair for ``run_scan``, or None. + + A reducer is the ``(init, update, finalize)`` triple passed as the ``reduce`` + kwarg. ``init(template, n_steps)`` sizes the accumulator: the per-step output + template is ``[n_vois, n_nodes]`` (the selected variables of interest, + matching the leading-after-time shape of a stacked trajectory) and + ``n_steps`` is the rollout length. Time-grid reducers that need ``dt`` (e.g. + BOLD decimation) take it at construction instead, since their per-block + update closes over static strides built before the framework calls ``init``. + """ + if reduce is None: + return None + init, update, _finalize = reduce + template = jnp.zeros((len(variable_names), n_nodes)) + return (init(template, n_steps), update) + + +def run_scan(op, state0, scan_inputs, n_steps, solver, fold=None, noise_gen=None): + """Run the integration scan, dispatching on the solver's gradient/memory knobs. + + The single seam where scan-level features live. Independent, nullable knobs + select the path: + + - ``noise_gen`` (streaming noise): ``noise_gen(block_idx, block_len)`` or + None. Set only when ``block_size`` is set and the network has noise with no + injected tensor; the per-block noise is generated in-scan from the absolute + block ordinal, so ``scan_inputs`` carries only the time signal (no noise + leaf) and the block step combines them. + - ``fold`` (the reduce output handler): ``(acc0, update)`` or None. When set + with ``block_size``, the trajectory is folded block-wise into ``acc`` and + the final carry exposes ``acc`` at index 1; the caller reads it and applies + ``finalize``. Requires ``block_size``; with ``block_size=None`` the caller + folds the stacked trajectory once instead (the degenerate single-block / + post-hoc case). + - ``grad_horizon`` (gradient horizon): if set, run a windowed scan that + severs the carry gradient every ``W`` steps. Snapped to a multiple of + ``block_size`` when both are set so window and block boundaries align. + - ``block_size`` (block granularity): with no truncation, None is the + plain single ``jax.lax.scan`` (the monolithic default, no-regression path) + and an int is the outer block scan that trades recompute for backward + memory and (with ``noise_gen``) streams the per-block noise. + + ``op`` consumes its per-step driving signals (time, and for SDEs the noise + slice) from its block inputs and is agnostic to how they were produced. + """ + block_size = solver.block_size + window = solver.grad_horizon + if window is not None and block_size is not None: + window = _snap_window(window, block_size) + + # Streaming and/or fold both ride on the block scan, which requires + # block_size. The block step encapsulates the streaming-vs-presampled and + # stack-vs-fold choices; _composed_scan owns the window/block composition. + if block_size is not None and (fold is not None or noise_gen is not None): + if fold is not None: + acc0, update = fold + if noise_gen is not None: + block_step = _stream_fold_block(op, update, noise_gen) + carry0 = (state0, acc0, jnp.array(0)) + else: + block_step = _fold_block(op, update) + carry0 = (state0, acc0) + return _composed_scan( + block_step, carry0, scan_inputs, n_steps, block_size, window + ) + # streaming stack: unwrap the (state, counter) carry to the bare state. + block_step = _stream_block(op, noise_gen) + (state, _counter), outs = _composed_scan( + block_step, + (state0, jnp.array(0)), + scan_inputs, + n_steps, + block_size, + window, + ) + return state, outs + + # Non-streaming, non-fold paths (presampled noise or none). + if window is not None: + return _truncated_scan(op, state0, scan_inputs, n_steps, window, block_size) + if block_size is None: + return jax.lax.scan(op, state0, scan_inputs) + return _block_scan(op, state0, scan_inputs, n_steps, block_size) + + def _make_diffusion_matrix_fn(diffusion_fn, state_indices, n_states, n_nodes): """Build a vectorized diffusion-matrix closure for Diffrax SDE integration. @@ -393,6 +728,7 @@ def prepare( t0: float = 0.0, t1: float = 1.0, dt: float = 0.1, + reduce=None, ) -> Tuple[Callable, Bunch]: """Compile a model into a pure JAX solve function and a config PyTree. @@ -639,24 +975,12 @@ def update_all_external_states(external_state_dict, new_network_state): # ========================================================================= # VARIABLES OF INTEREST - Determine what to record # ========================================================================= - voi_indices = network.dynamics.get_variables_of_interest_indices() - n_states = network.dynamics.N_STATES - - # Split VOI indices into state and auxiliary indices - state_voi_indices = jnp.array([i for i in voi_indices if i < n_states], dtype=int) - aux_voi_indices = jnp.array( - [i - n_states for i in voi_indices if i >= n_states], dtype=int - ) - - # Flag: do we need to record any auxiliaries? - record_auxiliaries = len(aux_voi_indices) > 0 - - # Variable names that label axis 1 of the output trajectory. - # Ordering mirrors the concatenation below: selected states, then selected auxiliaries. - _all_variable_names = network.dynamics.all_variable_names - variable_names = tuple( - _all_variable_names[i] for i in voi_indices if i < n_states - ) + tuple(_all_variable_names[i] for i in voi_indices if i >= n_states) + ( + state_voi_indices, + aux_voi_indices, + record_auxiliaries, + variable_names, + ) = _split_voi(network.dynamics) # Static shape for the full per-call noise tensor. n_steps = len(time_steps) @@ -671,20 +995,30 @@ def _f(config): # be computed with gradient flow while avoiding per-step redundancy. enriched = precompute_all_couplings(config) - # Materialize the full noise trajectory once per call. Trace-time - # branch: if the injection slot is None, sample from config.noise.key - # in a single PRNG call (XLA fuses this with the downstream scan); - # otherwise use the user-provided tensor verbatim. This avoids the - # per-step PRNG cost of an in-scan fold_in while preserving the - # seed-scan API (vary config.noise.key per call) and the injection - # path (set config._internal.noise_samples). - if network.noise is not None: - if config._internal.noise_samples is None: - noise_samples_all = jax.random.normal( - config.noise.key, noise_samples_shape - ) - else: - noise_samples_all = config._internal.noise_samples + # Noise source. Streaming (per-block fold_in) activates when blocking is + # on, the network has noise, and no full tensor is injected: it skips the + # O(n_steps) draw and regenerates each block's noise in-scan (and on the + # backward pass). Otherwise the full tensor is drawn once and fused with + # the scan (or the injected tensor is used verbatim). + streaming = ( + network.noise is not None + and solver.block_size is not None + and config._internal.noise_samples is None + ) + noise_gen = None + if streaming: + noise_samples_all = None + noise_gen = _streaming_noise_gen( + config.noise.key, + (n_noise_states, n_nodes), + provider=config._internal.get("noise_provider", None), + ) + elif network.noise is not None: + noise_samples_all = _materialize_noise( + config.noise.key, + config._internal.noise_samples, + noise_samples_shape, + ) else: noise_samples_all = None @@ -693,18 +1027,16 @@ def op(state, inputs): Args: state: Bunch(dynamics=network_state, coupling=coupling_state_dict, external=external_state_dict) - inputs: (t, step_idx) for SDE or just t for ODE + inputs: (t, noise_slice) for SDE or just t for ODE Returns: (next_state, output) tuple for scan """ - # Unpack inputs + # Unpack per-step driving signals from the scan inputs. if network.noise is not None: - t = inputs[0] - step_idx = jnp.int32(inputs[1]) + t, noise = inputs else: t = inputs - step_idx = None # By default compute all coupling inputs ONCE per step at the # step-start point (t, state.dynamics) and freeze them across @@ -746,9 +1078,7 @@ def wrapped_dynamics(t_inner, network_state, params_dynamics): # Prepare noise sample if stochastic noise_sample = jnp.zeros_like(state.dynamics) if network.noise is not None: - # Single indexed read into the per-call noise tensor. - noise = noise_samples_all[step_idx] - + # ``noise`` is the per-step slice handed in via scan inputs. # Compute diffusion coefficient diffusion = network.noise.diffusion(t, state.dynamics, config.noise) @@ -785,42 +1115,46 @@ def wrapped_dynamics(t_inner, network_state, params_dynamics): ) # Apply VARIABLES_OF_INTEREST filtering to build output - # Collect selected state variables - if len(state_voi_indices) > 0: - selected_states = next_dynamics_state[state_voi_indices] - else: - selected_states = jnp.array([]).reshape(0, next_dynamics_state.shape[1]) - - # Collect selected auxiliary variables if needed - if record_auxiliaries and auxiliaries.size > 0: - selected_aux = auxiliaries[aux_voi_indices] - # Concatenate states and auxiliaries - output = jnp.concatenate([selected_states, selected_aux], axis=0) - else: - output = selected_states + output = _assemble_output( + next_dynamics_state, + auxiliaries, + state_voi_indices, + aux_voi_indices, + record_auxiliaries, + ) # Return (carry, output) return next_state, output - # Prepare scan inputs - if network.noise is None: - # ODE/DDE: just time + # Prepare scan inputs. The per-step driving signals are the scan xs: + # ODE/DDE (and streaming SDE) carry just time; presampled SDE carries + # (time, noise_slice). When streaming, ``op`` still consumes (time, + # noise) per step but the noise is generated per block in-scan via + # ``noise_gen`` rather than sliced from a global tensor. + if network.noise is None or streaming: scan_inputs = time_steps else: - # SDE/SDDE: time + step index for noise lookup - scan_inputs = jnp.stack([time_steps, jnp.arange(len(time_steps))], axis=1) - - # Run integration. When checkpoint_every is None we go through the - # original single-scan path verbatim to guarantee no performance - # regression for the default setting; otherwise switch to an - # outer-checkpointed nested scan that trades ~2x recompute on the - # backward pass for ~O(n_steps/K + K) backward memory. - if solver.checkpoint_every is None: - _, res = jax.lax.scan(op, state0, scan_inputs) - else: - _, res = _checkpointed_scan( - op, state0, scan_inputs, n_steps, solver.checkpoint_every - ) + scan_inputs = (time_steps, noise_samples_all) + + # Run integration through the single scan seam, which dispatches on + # the solver's block knob (block_size), the streaming noise source, + # and the reduce fold. + fold = _reduce_fold(reduce, variable_names, n_nodes, n_steps) + final_carry, res = run_scan( + op, state0, scan_inputs, n_steps, solver, fold=fold, noise_gen=noise_gen + ) + + # With a reducer, return the finalized aggregate rather than a + # trajectory. Blocked: acc is threaded in the (state, acc) carry. + # Monolithic (block_size=None): fold the stacked trajectory once (the + # degenerate single-block / post-hoc case, no memory win). + if reduce is not None: + _init, update, finalize = reduce + if solver.block_size is None: + acc = update(fold[0], res) + else: + acc = final_carry[1] + return finalize(acc) # Wrap result for consistency return wrap_native_result(res, t0, t1, dt, variable_names=variable_names) @@ -835,6 +1169,7 @@ def prepare( t0: float = 0.0, t1: float = 1.0, dt: float = 0.1, + reduce=None, ) -> Tuple[Callable, Bunch]: """Compile a model into a pure JAX solve function and a config PyTree. @@ -864,6 +1199,16 @@ def prepare( # VALIDATION: Check for unsupported features # ========================================================================= + # reduce is a native-only feature (it rides on the native block scan). + # plum dispatches on the first two positional args only, so a reduce= meant + # for the native path can land here; reject it explicitly rather than via a + # bare TypeError on an unexpected keyword. + if reduce is not None: + raise ValueError( + "reduce is only supported by NativeSolver, not DiffraxSolver " + "(it rides on the native block scan). Use a NativeSolver." + ) + # Check for delayed coupling (stateful) if network.max_delay > 0.0: raise ValueError( @@ -1157,6 +1502,7 @@ def prepare( n_nodes: int = 1, noise=None, externals=None, + reduce=None, ) -> Tuple[Callable, Bunch]: """Compile a model into a pure JAX solve function and a config PyTree. @@ -1274,21 +1620,14 @@ def update_all_external_states(external_state_dict, new_state): # References dynamics_fn = dynamics.dynamics solver_step = solver.step - n_states = dynamics.N_STATES # VOI filtering - voi_indices = dynamics.get_variables_of_interest_indices() - state_voi_indices = jnp.array([i for i in voi_indices if i < n_states], dtype=int) - aux_voi_indices = jnp.array( - [i - n_states for i in voi_indices if i >= n_states], dtype=int - ) - record_auxiliaries = len(aux_voi_indices) > 0 - - # Variable names matching the output layout: selected states, then selected auxiliaries. - _all_variable_names = dynamics.all_variable_names - variable_names = tuple( - _all_variable_names[i] for i in voi_indices if i < n_states - ) + tuple(_all_variable_names[i] for i in voi_indices if i >= n_states) + ( + state_voi_indices, + aux_voi_indices, + record_auxiliaries, + variable_names, + ) = _split_voi(dynamics) # Static shape for the full per-call noise tensor. noise_samples_shape = (n_steps, n_noise_states, n_nodes) if has_noise else None @@ -1296,23 +1635,38 @@ def update_all_external_states(external_state_dict, new_state): def _f(config): """Pure integration function for bare dynamics.""" - # Materialize the full noise trajectory once per call. See the - # network+native dispatch for the rationale; the trace-time branch - # on the injection slot keeps both seed-scan and explicit-replay - # workflows on the same scan body. - if has_noise: - if config._internal.noise_samples is None: - noise_samples_all = jax.random.normal( - config.noise.key, noise_samples_shape - ) - else: - noise_samples_all = config._internal.noise_samples + # Noise source. Streaming (per-block fold_in) activates when blocking is + # on, the dynamics has noise, and no full tensor is injected; otherwise + # the full tensor is drawn once (or the injected tensor used). See the + # network+native dispatch for the rationale. + streaming = ( + has_noise + and solver.block_size is not None + and config._internal.noise_samples is None + ) + noise_gen = None + if streaming: + noise_samples_all = None + noise_gen = _streaming_noise_gen( + config.noise.key, + (n_noise_states, n_nodes), + provider=config._internal.get("noise_provider", None), + ) + elif has_noise: + noise_samples_all = _materialize_noise( + config.noise.key, + config._internal.noise_samples, + noise_samples_shape, + ) else: noise_samples_all = None def op(carry, scan_input): - t = scan_input[0] - step_idx = scan_input[1].astype(int) + # Unpack per-step driving signals from the scan inputs. + if has_noise: + t, noise_raw = scan_input + else: + t = scan_input # Unpack carry if has_externals: @@ -1330,9 +1684,8 @@ def wrapped_dynamics(t_inner, s, params): ext_inputs = zero_external return dynamics_fn(t_inner, s, params, zero_coupling, ext_inputs) - # Noise: single indexed read into the per-call tensor. + # Noise: ``noise_raw`` is the per-step slice from the scan inputs. if has_noise: - noise_raw = noise_samples_all[step_idx] diffusion = noise_diffusion(t, state, config.noise) scaled_noise = diffusion * jnp.sqrt(dt) * noise_raw noise_sample = jnp.zeros_like(state) @@ -1345,16 +1698,13 @@ def wrapped_dynamics(t_inner, s, params): ) # VOI filtering - if len(state_voi_indices) > 0: - selected_states = next_state[state_voi_indices] - else: - selected_states = jnp.array([]).reshape(0, next_state.shape[1]) - - if record_auxiliaries and auxiliaries.size > 0: - selected_aux = auxiliaries[aux_voi_indices] - output = jnp.concatenate([selected_states, selected_aux], axis=0) - else: - output = selected_states + output = _assemble_output( + next_state, + auxiliaries, + state_voi_indices, + aux_voi_indices, + record_auxiliaries, + ) # Update carry if has_externals: @@ -1367,20 +1717,32 @@ def wrapped_dynamics(t_inner, s, params): return next_carry, output - scan_inputs = jnp.stack( - [time_steps, jnp.arange(n_steps, dtype=time_steps.dtype)], axis=1 - ) - # See the Network+Native dispatch for the rationale on the branch. - if solver.checkpoint_every is None: - _, res = jax.lax.scan(op, config.initial_state, scan_inputs) + # Per-step driving signals as the scan xs: time alone for ODE (and + # streaming SDE, whose noise is generated per block in-scan), or + # (time, noise_slice) for presampled SDE. + if not has_noise or streaming: + scan_inputs = time_steps else: - _, res = _checkpointed_scan( - op, - config.initial_state, - scan_inputs, - n_steps, - solver.checkpoint_every, - ) + scan_inputs = (time_steps, noise_samples_all) + # Single scan seam; dispatches on the solver's block knob, the streaming + # noise source, and the reduce fold. + fold = _reduce_fold(reduce, variable_names, n_nodes, n_steps) + final_carry, res = run_scan( + op, + config.initial_state, + scan_inputs, + n_steps, + solver, + fold=fold, + noise_gen=noise_gen, + ) + if reduce is not None: + _init, update, finalize = reduce + if solver.block_size is None: + acc = update(fold[0], res) + else: + acc = final_carry[1] + return finalize(acc) return wrap_native_result(res, t0, t1, dt, variable_names=variable_names) return _f, config @@ -1396,6 +1758,7 @@ def prepare( n_nodes: int = 1, noise=None, externals=None, + reduce=None, ) -> Tuple[Callable, Bunch]: """Compile a model into a pure JAX solve function and a config PyTree. @@ -1421,6 +1784,14 @@ def prepare( for bare dynamics) and Diffrax limitations (no delays, no auxiliaries, no VOI filtering). """ + # reduce is a native-only feature (it rides on the native block scan); plum + # dispatches on positional args only, so reject a stray reduce= explicitly. + if reduce is not None: + raise ValueError( + "reduce is only supported by NativeSolver, not DiffraxSolver " + "(it rides on the native block scan). Use a NativeSolver." + ) + # Initial state [N_STATES, n_nodes] state0 = dynamics.get_default_initial_state(n_nodes) diff --git a/src/tvboptim/experimental/network_dynamics/solvers/native.py b/src/tvboptim/experimental/network_dynamics/solvers/native.py index 36cd058..1dba1be 100644 --- a/src/tvboptim/experimental/network_dynamics/solvers/native.py +++ b/src/tvboptim/experimental/network_dynamics/solvers/native.py @@ -26,8 +26,9 @@ class NativeSolver(AbstractSolver): def __init__( self, - checkpoint_every: int | None = None, + block_size: int | None = None, recompute_coupling_per_stage: bool = False, + grad_horizon: int | None = None, ): """ Args: @@ -52,35 +53,79 @@ def __init__( External inputs are always evaluated per stage regardless of this flag. - checkpoint_every: If None (default), no gradient checkpointing — - the integration scan runs as a single ``jax.lax.scan`` and - every step's carry is saved for the backward pass. If an int - ``K``, the scan is split into an outer scan over blocks of - ``K`` steps wrapped in ``jax.checkpoint``, with an inner - scan running the steps inside each block. The backward pass - then only retains block-boundary carries and recomputes - inner activations on demand. Trades roughly 1.3–1.7x - gradient wall time (one extra forward recompute, added to - an already backward-dominated cost) for - ``O(n_steps/K + K)`` backward memory instead of - ``O(n_steps)``. Peak memory is U-shaped in ``K``: small - ``K`` inflates the outer block-boundary tape - (``n_steps/K`` term), large ``K`` inflates the per-block - inner tape (``K`` term). The minimum sits near - ``K ≈ sqrt(n_steps)``. Has no effect on the forward-only - path. + block_size: The granularity of the nested block scan, and the + one block unit for the streaming features. If None (default), + the integration runs as a single ``jax.lax.scan`` and every + step's carry is saved for the backward pass (the monolithic + path; no blocking). If an int ``K``, the scan is split into an + outer scan over blocks of ``K`` steps each wrapped in + ``jax.checkpoint``, with an inner scan running the steps inside + a block. The backward pass then only retains block-boundary + carries and recomputes inner activations on demand. Trades + roughly 1.3-1.7x gradient wall time (one extra forward + recompute, added to an already backward-dominated cost) for + ``O(n_steps/K + K)`` backward memory instead of ``O(n_steps)``. + Peak memory is U-shaped in ``K``: small ``K`` inflates the + outer block-boundary tape (``n_steps/K`` term), large ``K`` + inflates the per-block inner tape (``K`` term). The minimum + sits near ``K = sqrt(n_steps)``. Has no effect on the + forward-only path. + + This same block is the unit that carries per-block streaming + noise and the online ``reduce`` accumulator when those per-call + features are used (so blocking *is* checkpointing; the + checkpoint grain is ``block_size``). ``block_size`` was formerly + named ``checkpoint_every``; once streaming noise is active a + blocked run reseeds the noise relative to the monolithic path + (see the Phase 4 plan). The memory model assumes a per-step carry whose size does not grow with ``n_steps``. This holds for the ``roll`` and ``circular`` delayed-coupling buffer strategies (history buffer size = ``max_delay_steps + 1``), but **not** for ``preallocated`` (history buffer grows linearly with - ``n_steps``). Checkpointing still works correctly with + ``n_steps``). Blocking still works correctly with ``preallocated``, but the practical memory win is much smaller because the carry itself dominates. + grad_horizon: If None (default), the gradient is exact over the + whole rollout. If an int ``W``, the integration runs as an + outer scan over windows of ``W`` steps with the carry + gradient severed (``jax.lax.stop_gradient``) at each window + boundary: truncated backpropagation through time. Credit is + assigned only within a window, so the backward ``T`` in the + ``exp(T·Lambda)`` sensitivity growth becomes ``W·dt`` instead + of the full horizon. The forward rollout is unchanged and + bit-exact; only the backward pass differs. + + ``grad_horizon`` and ``block_size`` are independent and + nest. ``block_size`` is a *memory* granularity (how big a + block to rematerialize on the backward pass, set by available + RAM); ``grad_horizon`` is the *gradient horizon* (how far back + credit is assigned, set by the slowest timescale of interest). + The two block sizes are not the same number. When both are set, + ``block_size`` tiles ``grad_horizon`` (each window is + rematerialized as ``W / block_size`` blocks); the clean case is + ``W % block_size == 0`` and ``n_steps % W == 0``, with + remainders handled by a tail scan. + + Severing the carry bounds the gradient horizon but does **not** + by itself bound activation memory: the full output is still + stacked and the loss taped over the whole rollout. For a long + rollout that needs both a bounded horizon and bounded memory, + set both knobs (``grad_horizon`` for the horizon, + ``block_size`` for memory within it). ``grad_horizon`` set + with ``block_size=None`` is a gradient-stability fix only. + + TBPTT biases toward fast timescales: any credit assignment + longer than ``W·dt`` is zeroed, so too short a window can null + the gradient of genuinely slow parameters. The window is a real + hyperparameter. For SDEs it is well-defined because noise is + fixed per call (common random numbers), so each window is a + deterministic map given its noise. """ - self.checkpoint_every = checkpoint_every + self.block_size = block_size self.recompute_coupling_per_stage = recompute_coupling_per_stage + self.grad_horizon = grad_horizon def step( self, @@ -325,9 +370,10 @@ def __init__( low: float | jnp.ndarray = -jnp.inf, high: float | jnp.ndarray = jnp.inf, ): - # Deliberately skip NativeSolver.__init__ — checkpoint_every is - # delegated to base_solver via the property below so that wrapping - # a checkpointed solver does not silently lose the setting. + # Deliberately skip NativeSolver.__init__: block_size, + # recompute_coupling_per_stage and grad_horizon are delegated to + # base_solver via the properties below so that wrapping a solver does + # not silently lose those settings. self.base_solver = base_solver low = jnp.asarray(low) high = jnp.asarray(high) @@ -335,13 +381,17 @@ def __init__( self.high = high[:, None] if high.ndim == 1 else high @property - def checkpoint_every(self): - return self.base_solver.checkpoint_every + def block_size(self): + return self.base_solver.block_size @property def recompute_coupling_per_stage(self): return self.base_solver.recompute_coupling_per_stage + @property + def grad_horizon(self): + return self.base_solver.grad_horizon + def step( self, dynamics_fn: Callable, diff --git a/src/tvboptim/observations/__init__.py b/src/tvboptim/observations/__init__.py index 0c7b8ec..d42e901 100644 --- a/src/tvboptim/observations/__init__.py +++ b/src/tvboptim/observations/__init__.py @@ -6,6 +6,7 @@ ks_distance, rmse, wasserstein_1d, + welford_cov, ) __all__ = [ @@ -16,4 +17,5 @@ "ks_distance", "rmse", "wasserstein_1d", + "welford_cov", ] diff --git a/src/tvboptim/observations/observation.py b/src/tvboptim/observations/observation.py index 1cb687d..5fc4c1d 100644 --- a/src/tvboptim/observations/observation.py +++ b/src/tvboptim/observations/observation.py @@ -35,6 +35,57 @@ def compute_fc(timeseries, s_var=0, mode=0, skip_t=0): return _fc.at[jnp.diag_indices(_fc.shape[0])].set(0) +def welford_cov(s_var=0): + """Online functional-connectivity reducer for the network solver. + + Returns a ``(init, update, finalize)`` triple for the ``reduce=`` kwarg of + ``prepare`` / ``solve``. It maintains a running mean and co-moment of the + chosen state variable's region time series via a block-wise Welford / Chan + merge (one batched ``X^T X`` per block, not a per-step rank-1 update), so + the accumulator is ``O(N^2)`` in the region count ``N`` and independent of + ``n_steps``. ``finalize`` returns the correlation matrix with a zeroed + diagonal, matching ``compute_fc(result, s_var=s_var)`` on the full + trajectory (the equivalence reference). + + Args: + s_var: Index into the variables-of-interest axis (axis 1 of the stacked + trajectory), matching ``compute_fc``'s ``s_var``. + """ + + def init(template, n_steps): + # template is one step's output [n_vois, n_nodes]; size from the + # region (node) axis. n_steps is unused (the state is O(1) in time). + n = template.shape[-1] + return (jnp.array(0.0), jnp.zeros(n), jnp.zeros((n, n))) + + def update(acc, block): + # block is [block_len, n_vois, n_nodes]; pick the chosen variable's + # region series and merge its block mean / co-moment into the running + # accumulator (Chan's parallel formula, exact up to float error). + count, mean, comoment = acc + x = block[:, s_var, :] + nb = x.shape[0] + mean_b = jnp.mean(x, axis=0) + xc = x - mean_b + comoment_b = xc.T @ xc + delta = mean_b - mean + new_count = count + nb + new_mean = mean + delta * (nb / new_count) + new_comoment = ( + comoment + comoment_b + jnp.outer(delta, delta) * (count * nb / new_count) + ) + return (new_count, new_mean, new_comoment) + + def finalize(acc): + count, _mean, comoment = acc + cov = comoment / count + d = jnp.sqrt(jnp.diag(cov)) + corr = cov / jnp.outer(d, d) + return corr.at[jnp.diag_indices(corr.shape[0])].set(0.0) + + return (init, update, finalize) + + def compute_fcd(timeseries, t_window, step_size, s_var=0, mode=0, skip_t=0): """Compute the functional connectivity dynamics (FCD) matrix. diff --git a/src/tvboptim/observations/tvb_monitors/__init__.py b/src/tvboptim/observations/tvb_monitors/__init__.py index 98679ea..bcf2008 100644 --- a/src/tvboptim/observations/tvb_monitors/__init__.py +++ b/src/tvboptim/observations/tvb_monitors/__init__.py @@ -8,6 +8,7 @@ HRFKernel, LotkaVolterraHRFKernel, MixtureOfGammasHRFKernel, + streaming_hrf_bold, ) from .downsampling import AbstractMonitor, SubSampling, TemporalAverage @@ -16,9 +17,13 @@ "SubSampling", "TemporalAverage", "HRFBold", + "streaming_hrf_bold", "BalloonWindkesselBold", "FirstOrderVolterraHRFKernel", "HRFKernel", + "DoubleExponentialHRFKernel", + "GammaHRFKernel", + "MixtureOfGammasHRFKernel", # Deprecated aliases "Bold", "LotkaVolterraHRFKernel", diff --git a/src/tvboptim/observations/tvb_monitors/bold.py b/src/tvboptim/observations/tvb_monitors/bold.py index 03eb545..21565bb 100644 --- a/src/tvboptim/observations/tvb_monitors/bold.py +++ b/src/tvboptim/observations/tvb_monitors/bold.py @@ -19,8 +19,10 @@ def _bold_variable_names(sol, voi=None): monitor does its own voi selection instead of delegating to a downsampler). Returns None if the source has no variable_names. """ - names = _slice_variable_names(sol, voi) if voi is not None else getattr( - sol, "variable_names", None + names = ( + _slice_variable_names(sol, voi) + if voi is not None + else getattr(sol, "variable_names", None) ) if names is None: return None @@ -142,6 +144,7 @@ def __call__(self, t: jax.Array, downsample_dt: float) -> jax.Array: / omega ) + class GammaHRFKernel(HRFKernel): """ Gamma HRF kernel, ported from TVBSim's Gamma class. @@ -166,7 +169,7 @@ class GammaHRFKernel(HRFKernel): J Neurosci 16: 4207-4221. """ - tau: float = 1.08 # seconds + tau: float = 1.08 # seconds n: float = 3.0 a: float = 0.1 duration: float = 20_000.0 # ms @@ -177,19 +180,19 @@ def __call__(self, t: jax.Array, downsample_dt: float) -> jax.Array: factorial = math.factorial(int(self.n) - 1) - kernel = ( - (t_s / self.tau) ** (self.n - 1) - * jnp.exp(-(t_s / self.tau)) - ) / (self.tau * factorial) + kernel = ((t_s / self.tau) ** (self.n - 1) * jnp.exp(-(t_s / self.tau))) / ( + self.tau * factorial + ) # Replicate TVBSim's normalization and amplitude scaling from evaluate() peak = jnp.max(kernel) - peak = jnp.where(peak > 0, peak, 1.0) # Avoid division by zero + peak = jnp.where(peak > 0, peak, 1.0) # Avoid division by zero kernel = kernel / peak kernel = kernel * self.a return kernel + class DoubleExponentialHRFKernel(HRFKernel): """ A difference of two exponential functions to define a kernel for the bold monitor, ported from TVBSim's DoubleExponential class. @@ -216,37 +219,44 @@ class DoubleExponentialHRFKernel(HRFKernel): Reference --------- - Alex Polonsky, Randolph Blake, Jochen Braun and David J. Heeger + Alex Polonsky, Randolph Blake, Jochen Braun and David J. Heeger (2000). Neuronal activity in human primary visual cortex correlates with perception during binocular rivalry. Nature Neuroscience 3: 1153-1159 """ - tau_1: float = 7.22 + tau_1: float = 7.22 tau_2: float = 7.4 f_1: float = 0.03 - f_2: float = 0.12 + f_2: float = 0.12 amp_1: float = 0.1 amp_2: float = 0.1 - a: float = 0.1 + a: float = 0.1 duration: float = 40_000.0 # ms def __call__(self, t: jax.Array, downsample_dt: float) -> jax.Array: # Convert ms to seconds t_s = t / 1000.0 - - kernel = ((self.amp_1 * jnp.exp(-t_s/self.tau_1) * jnp.sin(2 * math.pi * self.f_1 * t_s)) - - (self.amp_2 * jnp.exp(-t_s/self.tau_2) * jnp.sin(2 * math.pi * self.f_2 * t_s)) - ) + + kernel = ( + self.amp_1 + * jnp.exp(-t_s / self.tau_1) + * jnp.sin(2 * math.pi * self.f_1 * t_s) + ) - ( + self.amp_2 + * jnp.exp(-t_s / self.tau_2) + * jnp.sin(2 * math.pi * self.f_2 * t_s) + ) # Replicate TVBSim's normalization + amplitude scaling from evaluate() peak = jnp.max(kernel) - peak = jnp.where(peak > 0, peak, 1.0) # Avoid division by zero + peak = jnp.where(peak > 0, peak, 1.0) # Avoid division by zero kernel = kernel / peak kernel = kernel * self.a return kernel + class MixtureOfGammasHRFKernel(HRFKernel): """ Mixture of two gamma distributions HRF kernel, ported from TVBSim's MixtureOfGammas. @@ -285,13 +295,12 @@ def __call__(self, t: jax.Array, downsample_dt: float) -> jax.Array: gamma_a_1 = jsp.special.gamma(self.a_1) gamma_a_2 = jsp.special.gamma(self.a_2) - return ( - (self.l * t_s) ** (self.a_1 - 1) * jnp.exp(-self.l * t_s) / gamma_a_1 - - self.c - * (self.l * t_s) ** (self.a_2 - 1) - * jnp.exp(-self.l * t_s) - / gamma_a_2 - ) + return (self.l * t_s) ** (self.a_1 - 1) * jnp.exp( + -self.l * t_s + ) / gamma_a_1 - self.c * (self.l * t_s) ** (self.a_2 - 1) * jnp.exp( + -self.l * t_s + ) / gamma_a_2 + def LotkaVolterraHRFKernel(*args, **kwargs): """Deprecated: use FirstOrderVolterraHRFKernel. @@ -498,6 +507,101 @@ def convolve_single(x): ) +def streaming_hrf_bold(monitor, dt): + """Block-level streaming reducer form of :class:`HRFBold`. + + Returns an ``(init, update, finalize)`` triple for the ``reduce=`` kwarg of + ``prepare`` / ``solve``, computing the same BOLD signal as + ``monitor(full_solution)`` without ever stacking the full neural trajectory. + It reuses the monitor's kernel, periods and BOLD scaling, and carries a + downsampled-history ring buffer plus a preallocated BOLD output buffer; each + block subsamples, convolves ``[ring; block_downsampled]`` with the HRF + (``valid`` mode, same ``fftconvolve`` as the post-hoc monitor), and writes + the BOLD samples at TR boundaries into the buffer. + + Requirements (so blocks align with the decimation and BOLD grids): + + - the monitor must use a ``SubSampling`` downsample (uniform integer stride); + ``TemporalAverage``'s float-rounded windows are not streamable. + - ``block_size`` and ``n_steps`` must be multiples of the BOLD period in raw + steps (``period / dt``); the per-block update asserts this. + + Warm start / chaining: ``init`` seeds the ring from ``monitor.history`` when + set (the same warm-start the monitor accepts); ``finalize`` returns the BOLD + buffer ``[n_bold, n_voi, n_nodes]``. + """ + ds_period = monitor.downsample_period + period = monitor.period + voi = monitor.voi + k_1, V_0 = monitor.k_1, monitor.V_0 + conv_mode = monitor.convolution_mode + kernel = monitor.kernel + warm_history = monitor.history + + ds_steps = int(round(ds_period / dt)) + final_idx_step = int(round(period / ds_period)) + period_in_steps = ds_steps * final_idx_step + kernel_samples = int(math.ceil(kernel.duration / ds_period)) + hrf = kernel(jnp.linspace(0.0, kernel.duration, kernel_samples), ds_period) + + def _conv_valid(signal): + # signal [time, n_voi, n_nodes] -> valid HRF convolution along time, + # vectorized over nodes and variables exactly as HRFBold.__call__. + def convolve_single(x): + return jsp.signal.fftconvolve(x, hrf, mode=conv_mode) + + return jax.vmap( + lambda y: jax.vmap(convolve_single, in_axes=1, out_axes=1)(y), + in_axes=1, + out_axes=1, + )(signal) + + def init(template, n_steps): + t_sel = template[voi, :] # [n_voi, n_nodes] + n_voi, n_nodes = t_sel.shape + # SubSampling emits these indices; n_bold matches HRFBold's bold_indices. + n_ds = len(range(ds_steps - 1, n_steps, ds_steps)) + n_bold = len(range(final_idx_step, n_ds + 1, final_idx_step)) + ring0 = ( + jnp.zeros((kernel_samples, n_voi, n_nodes)) + if warm_history is None + else jnp.asarray(warm_history) + ) + bold0 = jnp.zeros((n_bold, n_voi, n_nodes)) + return (ring0, bold0, jnp.array(0)) + + def update(acc, block): + ring, bold_buffer, ds_count = acc + y = block[:, voi, :] # [block_len, n_voi, n_nodes] + block_len = y.shape[0] + assert block_len % period_in_steps == 0, ( + "streaming_hrf_bold requires each block length to be a multiple of " + f"the BOLD period in steps ({period_in_steps} = period/dt); got " + f"{block_len}. Set block_size and n_steps to multiples of period/dt." + ) + block_ds = y[ds_steps - 1 :: ds_steps] # SubSampling, [m_b, n_voi, n_nodes] + m_b = block_ds.shape[0] + signal = jnp.concatenate([ring, block_ds], axis=0) + conv = _conv_valid(signal) # [m_b + 1, n_voi, n_nodes] (valid) + conv = k_1 * V_0 * (conv - 1.0) + # BOLD samples at TR boundaries within this block (block-aligned so the + # output slots are contiguous). conv[i] == B_full[ds_count + i]. + idx = jnp.arange(final_idx_step, m_b + 1, final_idx_step) + bold_samples = conv[idx] + start = ds_count // final_idx_step + bold_buffer = jax.lax.dynamic_update_slice( + bold_buffer, bold_samples, (start,) + (0,) * (bold_buffer.ndim - 1) + ) + ring = signal[-kernel_samples:] + return (ring, bold_buffer, ds_count + m_b) + + def finalize(acc): + _ring, bold_buffer, _ds_count = acc + return bold_buffer + + return (init, update, finalize) + + class BalloonWindkesselBold(AbstractMonitor): """BOLD signal monitor using Balloon-Windkessel hemodynamic ODE. diff --git a/src/tvboptim/optim/callbacks.py b/src/tvboptim/optim/callbacks.py index 302398e..a17da6e 100644 --- a/src/tvboptim/optim/callbacks.py +++ b/src/tvboptim/optim/callbacks.py @@ -62,13 +62,9 @@ def __init__(self, every=1, *args: None) -> None: super().__init__( every, key="loss", - save_fun=lambda i, - diff_state, - static_state, - fitting_data, - aux_data, - loss_value, - grads: loss_value, + save_fun=lambda i, diff_state, static_state, fitting_data, aux_data, loss_value, grads: ( + loss_value + ), ) @@ -77,13 +73,9 @@ def __init__(self, every=1, *args: None) -> None: super().__init__( every, key="parameters", - save_fun=lambda i, - diff_state, - static_state, - fitting_data, - aux_data, - loss_value, - grads: diff_state, + save_fun=lambda i, diff_state, static_state, fitting_data, aux_data, loss_value, grads: ( + diff_state + ), ) diff --git a/tests/test_identifiability.py b/tests/test_identifiability.py index d840fc5..8f3156b 100644 --- a/tests/test_identifiability.py +++ b/tests/test_identifiability.py @@ -153,7 +153,9 @@ def test_eigendecompose_curvature_directly(self): # eigh of [[2,0],[0,8]] -> eigenvalues [2, 8]. res = eigendecompose_curvature( jnp.array([[2.0, 0.0], [0.0, 8.0]]), - labels=["p", "q"], theta0=jnp.zeros(2), kind="hessian", + labels=["p", "q"], + theta0=jnp.zeros(2), + kind="hessian", ) np.testing.assert_allclose(np.asarray(res.eigenvalues), [2.0, 8.0]) self.assertEqual(res.condition_number(), 4.0) @@ -199,9 +201,7 @@ def test_sigma_scaling(self): _, _, residual, _, state = self._linear_problem() FIM1, _, _ = fisher_information(residual, state, sigma=1.0) FIM2, _, _ = fisher_information(residual, state, sigma=2.0) - np.testing.assert_allclose( - np.asarray(FIM2), np.asarray(FIM1) / 4.0, atol=1e-9 - ) + np.testing.assert_allclose(np.asarray(FIM2), np.asarray(FIM1) / 4.0, atol=1e-9) def test_per_observation_sigma(self): M, _, residual, _, state = self._linear_problem() @@ -214,9 +214,7 @@ def test_fwd_and_rev_modes_agree(self): _, _, residual, _, state = self._linear_problem() FIM_fwd, _, _ = fisher_information(residual, state, mode="fwd") FIM_rev, _, _ = fisher_information(residual, state, mode="rev") - np.testing.assert_allclose( - np.asarray(FIM_fwd), np.asarray(FIM_rev), atol=1e-9 - ) + np.testing.assert_allclose(np.asarray(FIM_fwd), np.asarray(FIM_rev), atol=1e-9) def test_scalar_model_warns(self): _, _, _, _, state = self._linear_problem() @@ -238,7 +236,6 @@ def model(state): res = eigendecompose_curvature(FIM, labels, theta0, kind="fisher") self.assertEqual(res.rank(), 1) - flat = res.sloppy_directions(1)[0] ia, ib = res.labels.index("a"), res.labels.index("b") vec = np.array([res.eigenvectors[ia, 0], res.eigenvectors[ib, 0]]) vec /= np.linalg.norm(vec) @@ -345,9 +342,7 @@ def loss_theta(theta): # The FIM and the FD loss Hessian must share the flat direction. v_fim = np.asarray(res.eigenvectors[:, 0]) v_fd = evecs_fd[:, 0] - cos = abs( - v_fim @ v_fd / (np.linalg.norm(v_fim) * np.linalg.norm(v_fd)) - ) + cos = abs(v_fim @ v_fd / (np.linalg.norm(v_fim) * np.linalg.norm(v_fd))) self.assertGreater(cos, 0.95) # The amplitude / excitability degeneracy is a genuine sloppy ridge. diff --git a/tests/test_network_dynamics/test_adiabatic_scan.py b/tests/test_network_dynamics/test_adiabatic_scan.py new file mode 100644 index 0000000..e5804db --- /dev/null +++ b/tests/test_network_dynamics/test_adiabatic_scan.py @@ -0,0 +1,117 @@ +"""Test the adiabatic_scan analysis helper across model/branch combinations.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +jax.config.update("jax_enable_x64", True) + +from tvboptim.experimental.network_dynamics import Network +from tvboptim.experimental.network_dynamics.analysis import ( + AdiabaticScanResult, + adiabatic_scan, +) +from tvboptim.experimental.network_dynamics.coupling import LinearCoupling +from tvboptim.experimental.network_dynamics.dynamics.tvb import ( + JansenRit, + ReducedWongWang, +) +from tvboptim.experimental.network_dynamics.graph import DenseGraph + +# (model class, coupling state variable) — kept generic/parametrized +MODELS = [ + (ReducedWongWang, "S"), + (JansenRit, "y1"), +] + + +def _build_network(model_class, coupling_var, n_nodes=4, seed=0): + graph = DenseGraph.random(n_nodes=n_nodes, key=jax.random.PRNGKey(seed)) + coupling = LinearCoupling(incoming_states=coupling_var, G=0.1) + return Network( + dynamics=model_class(), + coupling={"instant": coupling}, + graph=graph, + ) + + +@pytest.mark.parametrize("model_class,coupling_var", MODELS) +@pytest.mark.parametrize("bothways", [True, False]) +def test_adiabatic_scan_shapes_and_finiteness(model_class, coupling_var, bothways): + network = _build_network(model_class, coupling_var) + n = 4 + res = adiabatic_scan( + network, + accessor=lambda c: c.coupling.instant.G, + low=0.0, + high=0.5, + n=n, + t=20.0, + skip=10.0, + dt=1.0, + bothways=bothways, + ) + + assert isinstance(res, AdiabaticScanResult) + assert res.n_up == n + expected_len = 2 * n if bothways else n + assert len(res.p) == expected_len + + # default statistics present, aligned with p, and finite + assert set(res.stats) == {"mean", "min", "max"} + # Bunch allows attribute access alongside key access + np.testing.assert_array_equal(res.stats.mean, res.stats["mean"]) + for arr in res.stats.values(): + assert arr.shape == (expected_len,) + assert np.all(np.isfinite(arr)) + + # min <= mean <= max at every scan point + assert np.all(res.stats["min"] <= res.stats["mean"] + 1e-8) + assert np.all(res.stats["mean"] <= res.stats["max"] + 1e-8) + + # ascending branch sweeps the requested bounds + np.testing.assert_allclose(res.p[0], 0.0, atol=1e-12) + np.testing.assert_allclose(res.p[n - 1], 0.5, atol=1e-12) + + +def test_adiabatic_scan_custom_observe_and_statistics(): + network = _build_network(ReducedWongWang, "S") + res = adiabatic_scan( + network, + accessor=lambda c: c.coupling.instant.G, + low=0.0, + high=0.3, + n=3, + t=20.0, + skip=10.0, + dt=1.0, + bothways=False, + observe=lambda result: result.ys[:, 0, :], + statistics={"std": lambda arr: jnp.std(arr.mean(axis=0))}, + ) + assert set(res.stats) == {"std"} + assert res.stats["std"].shape == (3,) + assert np.all(np.isfinite(res.stats["std"])) + + +def test_adiabatic_scan_vector_valued_statistic(): + n_nodes = 4 + network = _build_network(ReducedWongWang, "S", n_nodes=n_nodes) + n = 3 + res = adiabatic_scan( + network, + accessor=lambda c: c.coupling.instant.G, + low=0.0, + high=0.3, + n=n, + t=20.0, + skip=10.0, + dt=1.0, + bothways=False, + # per-node temporal mean -> [n_nodes], stacked to [len(p), n_nodes] + statistics={"per_node": lambda arr: arr.mean(axis=0)}, + ) + assert set(res.stats) == {"per_node"} + assert res.stats["per_node"].shape == (n, n_nodes) + assert np.all(np.isfinite(res.stats["per_node"])) diff --git a/tests/test_network_dynamics/test_basic_networks.py b/tests/test_network_dynamics/test_basic_networks.py index 70a1028..4d534ba 100644 --- a/tests/test_network_dynamics/test_basic_networks.py +++ b/tests/test_network_dynamics/test_basic_networks.py @@ -23,7 +23,7 @@ from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph, DenseGraph from tvboptim.experimental.network_dynamics.noise import AdditiveNoise from tvboptim.experimental.network_dynamics.solve import prepare -from tvboptim.experimental.network_dynamics.solvers import Heun +from tvboptim.experimental.network_dynamics.solvers import BoundedSolver, Heun class TestBasicNetworks(unittest.TestCase): @@ -283,9 +283,7 @@ def test_noise_injection_matches_key_path(self): n_steps = int(round((5.0 - 0.0) / 0.1)) n_noise_states = len(network.noise._state_indices) n_nodes = network.graph.n_nodes - samples = jax.random.normal( - cfg.noise.key, (n_steps, n_noise_states, n_nodes) - ) + samples = jax.random.normal(cfg.noise.key, (n_steps, n_noise_states, n_nodes)) cfg_inj = eqx.tree_at( lambda c: c._internal.noise_samples, cfg, @@ -438,6 +436,12 @@ class TestCheckpointedScan(unittest.TestCase): """Verify that the block-checkpointed scan path matches the unchecked single-scan path bit-exactly on forward and to numerical precision on backward, for both divisor and non-divisor block sizes. + + The network is deterministic (no noise) so ``block_size`` is pure + rematerialization here: it does NOT stream noise, so the blocked result is + bit-exact to the monolithic one. The streaming-noise behaviour that + ``block_size`` also triggers for an SDE network (which reseeds, so it is + deliberately not bit-exact to monolithic) is covered by ``TestStreamingNoise``. """ def _build_dde_network(self): @@ -452,14 +456,14 @@ def _build_dde_network(self): dynamics=ReducedWongWang(), coupling={"delayed": coupling}, graph=graph, - noise=AdditiveNoise(sigma=1e-4, key=jax.random.key(0)), + noise=None, ) - def _run(self, checkpoint_every, t1=20.0, dt=0.1): + def _run(self, block_size, t1=20.0, dt=0.1): network = self._build_dde_network() solve_fn, cfg = prepare( network, - Heun(checkpoint_every=checkpoint_every), + Heun(block_size=block_size), t0=0.0, t1=t1, dt=dt, @@ -491,8 +495,8 @@ def test_gradient_matches_baseline(self): unckecpointed gradient to numerical precision, for both divisor and non-divisor block sizes.""" - def make_loss(checkpoint_every): - solve_fn, cfg = self._run(checkpoint_every) + def make_loss(block_size): + solve_fn, cfg = self._run(block_size) def loss(G): import equinox as eqx @@ -526,21 +530,19 @@ def loss(G): def test_default_is_none(self): """Sentinel: the default constructor must not enable checkpointing. Guards the no-perf-regression contract for existing call sites.""" - self.assertIsNone(Heun().checkpoint_every) + self.assertIsNone(Heun().block_size) def test_bare_dynamics_dispatch(self): - """Bare-dynamics+native path also branches on checkpoint_every.""" + """Bare-dynamics+native path also branches on block_size.""" from tvboptim.experimental.network_dynamics.dynamics.tvb import ( ReducedWongWang, ) dyn = ReducedWongWang() - solve_none, cfg = prepare( - dyn, Heun(), t0=0.0, t1=10.0, dt=0.1, n_nodes=3 - ) + solve_none, cfg = prepare(dyn, Heun(), t0=0.0, t1=10.0, dt=0.1, n_nodes=3) solve_ckpt, _ = prepare( dyn, - Heun(checkpoint_every=11), # non-divisor of 100 steps + Heun(block_size=11), # non-divisor of 100 steps t0=0.0, t1=10.0, dt=0.1, @@ -580,9 +582,7 @@ def _build_cases(self): Network( dynamics=ReducedWongWang(), coupling={ - "delayed": DelayedLinearCoupling( - incoming_states=["S"], G=0.3 - ) + "delayed": DelayedLinearCoupling(incoming_states=["S"], G=0.3) }, graph=DenseDelayGraph(weights, delays), ), @@ -626,8 +626,9 @@ def test_config_mutation_does_not_leak(self): f"{name}: re-prepare returned mutated value", ) self.assertTrue( - jnp.array_equal(jnp.asarray(cfg2.dynamics.w), - jnp.asarray(original_w)), + jnp.array_equal( + jnp.asarray(cfg2.dynamics.w), jnp.asarray(original_w) + ), f"{name}: re-prepare did not restore original w", ) @@ -637,5 +638,562 @@ def test_config_mutation_does_not_leak(self): jax.block_until_ready(jax.jit(solve_fn)(cfg3)) +class TestSolveHelpers(unittest.TestCase): + """Pin the contract of the Phase 1 module-level helpers extracted from the + duplicated native dispatch bodies. These are covered transitively by the + bit-exact integration suite; this class documents the helpers directly so a + future refactor of one cannot silently drift from the inline logic it + replaced. + """ + + DYNAMICS = [JansenRit(), ReducedWongWang()] + + def test_split_voi_matches_inline_reference(self): + from tvboptim.experimental.network_dynamics.solve import _split_voi + + for dyn in self.DYNAMICS: + with self.subTest(dynamics=type(dyn).__name__): + # Reference: the exact inline logic the helper replaced. + voi = dyn.get_variables_of_interest_indices() + n_states = dyn.N_STATES + names = dyn.all_variable_names + ref_state = jnp.array([i for i in voi if i < n_states], dtype=int) + ref_aux = jnp.array( + [i - n_states for i in voi if i >= n_states], dtype=int + ) + ref_record = len(ref_aux) > 0 + ref_names = tuple(names[i] for i in voi if i < n_states) + tuple( + names[i] for i in voi if i >= n_states + ) + + state_idx, aux_idx, record, var_names = _split_voi(dyn) + self.assertTrue(jnp.array_equal(state_idx, ref_state)) + self.assertTrue(jnp.array_equal(aux_idx, ref_aux)) + self.assertEqual(record, ref_record) + self.assertEqual(var_names, ref_names) + # Labels match the number of recorded rows. + self.assertEqual(len(var_names), len(ref_state) + len(ref_aux)) + + def test_materialize_noise_draw_and_injection(self): + from tvboptim.experimental.network_dynamics.solve import _materialize_noise + + shape = (5, 2, 3) + key = jax.random.key(0) + # Default provider: single fused draw, reproducible from the key. + drawn = _materialize_noise(key, None, shape) + self.assertEqual(drawn.shape, shape) + self.assertTrue(jnp.array_equal(drawn, jax.random.normal(key, shape))) + # Injection: passed through verbatim, ignoring the key. + injected = jnp.ones(shape) + out = _materialize_noise(key, injected, shape) + self.assertIs(out, injected) + + def test_assemble_output_layout(self): + from tvboptim.experimental.network_dynamics.solve import _assemble_output + + n_nodes = 3 + next_state = jnp.arange(4 * n_nodes, dtype=float).reshape(4, n_nodes) + aux = jnp.arange(2 * n_nodes, dtype=float).reshape(2, n_nodes) + 100.0 + + # States only. + out = _assemble_output( + next_state, aux, jnp.array([0, 2]), jnp.array([], dtype=int), False + ) + self.assertTrue(jnp.array_equal(out, next_state[jnp.array([0, 2])])) + + # States followed by selected auxiliaries. + out = _assemble_output(next_state, aux, jnp.array([1]), jnp.array([0]), True) + expected = jnp.concatenate( + [next_state[jnp.array([1])], aux[jnp.array([0])]], axis=0 + ) + self.assertTrue(jnp.array_equal(out, expected)) + + +class TestTruncatedScan(unittest.TestCase): + """Verify the truncated-BPTT windowed scan (`_truncated_scan`). + + Forward must stay bit-exact to a plain single scan (``stop_gradient`` is the + identity on the forward); the backward must (a) match an independent + hand-rolled windowed reference, (b) reduce to the full exact gradient in the + degenerate single-window case, and (c) be invariant to the nested memory + knob ``block_size``. A network-level case checks the wiring through + ``prepare`` / ``run_scan`` and the ``BoundedSolver`` delegation. + + The combinator is tested with a toy linear recurrence so the reference is + independent and fully controlled, mirroring ``TestCheckpointedScan`` for the + network-level checks. + """ + + N = 20 # toy rollout length + C0 = jnp.asarray(0.3) + + @staticmethod + def _make_op(a): + # Toy step: carry' = a * carry + x, output = carry. ``a`` is the + # closed-over parameter we differentiate (stands in for theta). + def op(c, x): + nc = a * c + x + return nc, nc + + return op + + def _xs(self): + return jnp.arange(self.N, dtype=float) + 1.0 + + def _full(self, a): + _, ys = jax.lax.scan(self._make_op(a), self.C0, self._xs()) + return jnp.sum(ys) + + def _trunc(self, a, window, block_size): + from tvboptim.experimental.network_dynamics.solve import _truncated_scan + + _, ys = _truncated_scan( + self._make_op(a), self.C0, self._xs(), self.N, window, block_size + ) + return jnp.sum(ys) + + def _ref_trunc(self, a, window): + # Independent reference: an unrolled Python loop over windows that + # severs the carry gradient at each window entry, exactly the truncated + # estimator. No reshape/tail machinery, so it cannot share a bug with + # `_truncated_scan`. + op = self._make_op(a) + xs = self._xs() + c = self.C0 + outs = [] + i = 0 + while i < self.N: + j = min(i + window, self.N) + c = jax.lax.stop_gradient(c) + c, block = jax.lax.scan(op, c, xs[i:j]) + outs.append(block) + i = j + return jnp.sum(jnp.concatenate(outs, axis=0)) + + def test_forward_bitexact(self): + """Truncation does not touch the forward value, for divisor, + non-divisor and degenerate window sizes.""" + for window in (5, 7, self.N + 3): + with self.subTest(window=window): + self.assertTrue( + jnp.array_equal(self._full(0.5), self._trunc(0.5, window, None)) + ) + + def test_gradient_matches_reference(self): + """Truncated gradient equals the unrolled windowed reference, and is a + genuine truncation (differs from the full gradient).""" + g_full = jax.grad(self._full)(0.5) + for window in (5, 7): + with self.subTest(window=window): + g_trunc = jax.grad(lambda a: self._trunc(a, window, None))(0.5) + g_ref = jax.grad(lambda a: self._ref_trunc(a, window))(0.5) + self.assertTrue( + jnp.allclose(g_trunc, g_ref, rtol=1e-10, atol=1e-12), + f"window={window}: {g_trunc} vs ref {g_ref}", + ) + # The truncation is real: short windows drop cross-window credit. + self.assertFalse(jnp.allclose(g_trunc, g_full, rtol=1e-6)) + + def test_degenerate_window_equals_full(self): + """A single window (window >= n_steps) recovers the full exact gradient: + severing the leaf initial carry does not affect the parameter gradient.""" + g_full = jax.grad(self._full)(0.5) + for window in (self.N, self.N + 5): + with self.subTest(window=window): + g = jax.grad(lambda a: self._trunc(a, window, None))(0.5) + self.assertTrue(jnp.allclose(g_full, g, rtol=1e-10, atol=1e-12)) + + def test_gradient_invariant_to_block_size(self): + """Within a fixed window, subdividing into checkpoint sub-blocks + rematerializes activations but must not change the gradient, including + non-divisor sub-blocks and a sub-block larger than the window.""" + window = 10 + g_none = jax.grad(lambda a: self._trunc(a, window, None))(0.5) + for ce in (5, 3, 25): # divisor, non-divisor, larger-than-window + with self.subTest(block_size=ce): + g = jax.grad(lambda a: self._trunc(a, window, ce))(0.5) + self.assertTrue( + jnp.allclose(g_none, g, rtol=1e-10, atol=1e-12), + f"block_size={ce}: {g} vs {g_none}", + ) + + def _build_sde_network(self): + key = jax.random.PRNGKey(11) + wkey, dkey = jax.random.split(key) + n_nodes = 4 + weights = jax.random.uniform(wkey, (n_nodes, n_nodes)) * 0.5 + delays = jax.random.uniform(dkey, (n_nodes, n_nodes)) * 5.0 + graph = DenseDelayGraph(weights=weights, delays=delays) + coupling = DelayedLinearCoupling(incoming_states="S", G=0.1) + return Network( + dynamics=ReducedWongWang(), + coupling={"delayed": coupling}, + graph=graph, + noise=AdditiveNoise(sigma=1e-4, key=jax.random.key(0)), + ) + + def test_network_forward_bitexact(self): + """Through prepare/run_scan, the truncated path's forward trajectory is + bit-exact to the non-truncated path (truncation changes only gradients), + for a divisor and a non-divisor window of the 200-step rollout.""" + network = self._build_sde_network() + solve_full, cfg = prepare(network, Heun(), t0=0.0, t1=20.0, dt=0.1) + r_full = solve_full(cfg) + for window in (20, 13): # 200 % 20 == 0, 200 % 13 == 5 + with self.subTest(window=window): + solve_trunc, _ = prepare( + network, Heun(grad_horizon=window), t0=0.0, t1=20.0, dt=0.1 + ) + self.assertTrue(jnp.array_equal(r_full.ys, solve_trunc(cfg).ys)) + + def test_default_and_bounded_delegation(self): + """Default constructor leaves truncation off; BoundedSolver forwards the + knob from its base solver.""" + self.assertIsNone(Heun().grad_horizon) + self.assertEqual(Heun(grad_horizon=50).grad_horizon, 50) + self.assertEqual( + BoundedSolver(Heun(grad_horizon=50), low=0.0, high=1.0).grad_horizon, + 50, + ) + + +class TestReduce(unittest.TestCase): + """Verify the block-level reduce (fold) path. + + The fold folds each block's stacked outputs into an accumulator carried in + the scan instead of stacking the whole trajectory. Checked at two levels: a + toy running-sum reducer through ``run_scan`` (the independent reference is + the sum over the plainly-stacked trajectory), and a network-level online + ``welford_cov`` FC pinned against the post-hoc ``compute_fc``. The fold must + match for divisor / non-divisor / degenerate blocks, compose with + ``grad_horizon``, and be invariant (value and gradient) to ``block_size``. + """ + + N = 20 # toy rollout length + C0 = jnp.asarray(0.3) + + @staticmethod + def _make_op(a): + # Toy step: carry' = a * carry + x, output = carry (matches + # TestTruncatedScan so the reference is independent and controlled). + def op(c, x): + nc = a * c + x + return nc, nc + + return op + + def _xs(self): + return jnp.arange(self.N, dtype=float) + 1.0 + + # Toy reducer update: acc is the scalar running sum of all step outputs. + _UPDATE = staticmethod(lambda acc, block: acc + jnp.sum(block)) + + def _stacked_sum(self, a): + _, ys = jax.lax.scan(self._make_op(a), self.C0, self._xs()) + return jnp.sum(ys) + + def _fold_sum(self, a, block_size, window=None): + from tvboptim.experimental.network_dynamics.solve import run_scan + + solver = Heun(block_size=block_size, grad_horizon=window) + carry, _ = run_scan( + self._make_op(a), + self.C0, + self._xs(), + self.N, + solver, + fold=(jnp.asarray(0.0), self._UPDATE), + ) + return carry[1] # the accumulator threaded in the (state, acc) carry + + def test_fold_equals_stacked_sum(self): + """Folded accumulator equals the sum over the plainly-stacked + trajectory, for divisor / non-divisor / degenerate blocks, with and + without a truncation window (the forward value is identical).""" + ref = self._stacked_sum(0.5) + for block_size in (5, 7, self.N, 4): + # Windows are None or multiples of block_size (so no snapping fires; + # snapping is covered in TestStreamingNoise). + for window in (None, 2 * block_size, 3 * block_size): + with self.subTest(block_size=block_size, window=window): + got = self._fold_sum(0.5, block_size, window) + self.assertTrue( + jnp.allclose(got, ref, rtol=1e-12, atol=1e-12), + f"bs={block_size}, w={window}: {got} vs {ref}", + ) + + def test_fold_value_and_grad_invariant_to_block_size(self): + """Without truncation the fold is the exact gradient (checkpoint + rematerialization), so both value and gradient are invariant to + ``block_size``, including non-divisor and larger-than-rollout blocks.""" + v_ref = self._stacked_sum(0.5) + g_ref = jax.grad(self._stacked_sum)(0.5) + for block_size in (5, 7, 4, self.N): + with self.subTest(block_size=block_size): + v = self._fold_sum(0.5, block_size, None) + g = jax.grad(lambda a: self._fold_sum(a, block_size, None))(0.5) + self.assertTrue(jnp.allclose(v, v_ref, rtol=1e-12, atol=1e-12)) + self.assertTrue( + jnp.allclose(g, g_ref, rtol=1e-10, atol=1e-12), + f"bs={block_size}: grad {g} vs {g_ref}", + ) + + def _build_sde_network(self): + key = jax.random.PRNGKey(11) + wkey, dkey = jax.random.split(key) + n_nodes = 5 + weights = jax.random.uniform(wkey, (n_nodes, n_nodes)) * 0.5 + delays = jax.random.uniform(dkey, (n_nodes, n_nodes)) * 5.0 + graph = DenseDelayGraph(weights=weights, delays=delays) + coupling = DelayedLinearCoupling(incoming_states="S", G=0.1) + return Network( + dynamics=ReducedWongWang(), + coupling={"delayed": coupling}, + graph=graph, + noise=AdditiveNoise(sigma=1e-3, key=jax.random.key(0)), + ) + + def test_welford_matches_compute_fc(self): + """Online ``welford_cov`` over a blocked (streaming-noise) run equals the + post-hoc ``compute_fc`` on the matching streamed trajectory, for a + divisor and a non-divisor block size. The reference uses the SAME + ``block_size`` (hence the same per-block seeding), not the monolithic + global draw: blocked mode streams noise and reseeds, so the only valid + comparison is online-fold vs stack-then-post-hoc at matched seeding.""" + from tvboptim.observations import compute_fc, welford_cov + + net = self._build_sde_network() + for block_size in (50, 37): # 300 steps: divisor / non-divisor + with self.subTest(block_size=block_size): + # Stacked trajectory with this block_size (streamed noise) -> FC. + fc_ref = compute_fc( + solve(net, Heun(block_size=block_size), t0=0.0, t1=30.0, dt=0.1), + s_var=0, + ) + # Online welford over the same streamed run. + fc = solve( + net, + Heun(block_size=block_size), + t0=0.0, + t1=30.0, + dt=0.1, + reduce=welford_cov(s_var=0), + ) + self.assertEqual(fc.shape, fc_ref.shape) + self.assertTrue( + jnp.allclose(fc, fc_ref, atol=1e-4), + f"bs={block_size}: max diff {jnp.max(jnp.abs(fc - fc_ref))}", + ) + + def test_welford_monolithic_equals_compute_fc(self): + """``block_size=None`` with a reducer folds the whole stacked trajectory + once (the degenerate single-block / post-hoc case) and equals + ``compute_fc``.""" + from tvboptim.observations import compute_fc, welford_cov + + net = self._build_sde_network() + solve_full, cfg = prepare(net, Heun(), t0=0.0, t1=30.0, dt=0.1) + fc_ref = compute_fc(solve_full(cfg), s_var=0) + solve_fc, _ = prepare( + net, Heun(), t0=0.0, t1=30.0, dt=0.1, reduce=welford_cov(s_var=0) + ) + self.assertTrue(jnp.allclose(solve_fc(cfg), fc_ref, atol=1e-4)) + + def test_welford_differentiable_and_tbptt_invariant(self): + """The online FC is differentiable wrt a coupling gain, and (for a fixed + ``block_size``, hence fixed noise) its forward value is invariant to + ``grad_horizon`` snapped to a multiple of ``block_size`` -- truncation + changes only the gradient, not the streamed realisation.""" + import equinox as eqx + + from tvboptim.observations import welford_cov + + net = self._build_sde_network() + + # Differentiable wrt G on a blocked streaming run. + solve_fc, cfg = prepare( + net, + Heun(block_size=50), + t0=0.0, + t1=30.0, + dt=0.1, + reduce=welford_cov(s_var=0), + ) + + def loss(G): + c = eqx.tree_at(lambda c: c.coupling.delayed.G, cfg, G) + return jnp.sum(solve_fc(c) ** 2) + + v, g = jax.value_and_grad(loss)(jnp.asarray(0.1)) + self.assertTrue(jnp.isfinite(g)) + + # Forward FC invariant to a truncation window (multiple of block_size). + fc_base = solve( + net, + Heun(block_size=50), + t0=0.0, + t1=30.0, + dt=0.1, + reduce=welford_cov(s_var=0), + ) + for window in (100, 150, 300): # multiples of block_size=50 + with self.subTest(window=window): + fc_w = solve( + net, + Heun(block_size=50, grad_horizon=window), + t0=0.0, + t1=30.0, + dt=0.1, + reduce=welford_cov(s_var=0), + ) + self.assertTrue(jnp.array_equal(fc_base, fc_w)) + + def test_diffrax_rejects_reduce(self): + """``reduce`` is native-only; the Diffrax dispatch raises a clear error + rather than a bare TypeError on an unexpected keyword.""" + import diffrax + + from tvboptim.experimental.network_dynamics.coupling import LinearCoupling + from tvboptim.experimental.network_dynamics.solvers.diffrax import ( + DiffraxSolver, + ) + from tvboptim.observations import welford_cov + + n_nodes = 3 + net = Network( + dynamics=ReducedWongWang(), + coupling={"instant": LinearCoupling(incoming_states="S", G=0.1)}, + graph=DenseGraph( + jax.random.uniform(jax.random.PRNGKey(1), (n_nodes, n_nodes)) + ), + ) + with self.assertRaises(ValueError): + prepare( + net, + DiffraxSolver(solver=diffrax.Heun()), + t0=0.0, + t1=5.0, + dt=0.1, + reduce=welford_cov(), + ) + + +class TestStreamingNoise(unittest.TestCase): + """Per-block streaming noise (``fold_in``) under ``block_size``. + + Streaming activates for an SDE network with ``block_size`` set and no + injected tensor. The realisation is a pure function of + ``(key, absolute_block_idx)`` and the block grain, so it is deterministic, + invariant to the truncation window, and matches an independent reference + that folds the noise in the same way. It deliberately reseeds relative to + the monolithic global draw (documented). + """ + + T1 = 30.0 + DT = 0.1 # 300 steps + + def _build(self): + key = jax.random.PRNGKey(11) + wkey, dkey = jax.random.split(key) + n_nodes = 5 + weights = jax.random.uniform(wkey, (n_nodes, n_nodes)) * 0.5 + delays = jax.random.uniform(dkey, (n_nodes, n_nodes)) * 5.0 + graph = DenseDelayGraph(weights=weights, delays=delays) + coupling = DelayedLinearCoupling(incoming_states="S", G=0.1) + return Network( + dynamics=ReducedWongWang(), + coupling={"delayed": coupling}, + graph=graph, + noise=AdditiveNoise(sigma=1e-3, key=jax.random.key(0)), + ) + + def test_reseed_and_determinism(self): + """Blocked streaming reseeds vs monolithic; is deterministic for a fixed + block_size; and a different block grain gives a different realisation.""" + net = self._build() + kw = dict(t0=0.0, t1=self.T1, dt=self.DT) + mono = solve(net, Heun(), **kw).ys + a = solve(net, Heun(block_size=50), **kw).ys + b = solve(net, Heun(block_size=50), **kw).ys + c = solve(net, Heun(block_size=37), **kw).ys + self.assertFalse(jnp.allclose(mono, a)) # reseed vs global draw + self.assertTrue(jnp.array_equal(a, b)) # deterministic + self.assertFalse(jnp.allclose(a, c)) # block grain changes realisation + + def test_matches_matched_seeding_reference(self): + """Streaming forward equals an independent reference that builds the full + noise tensor by the same per-block ``fold_in`` and injects it into a + monolithic run, for a divisor and a non-divisor block size (exercising + the tail block).""" + net = self._build() + n_steps = len(jnp.arange(0.0, self.T1, self.DT)) + n_noise = len(net.noise._state_indices) + n_nodes = net.graph.n_nodes + key = net.noise.key + for block_size in (50, 37): + with self.subTest(block_size=block_size): + strm = solve( + net, Heun(block_size=block_size), t0=0.0, t1=self.T1, dt=self.DT + ).ys + # Independent reference: per-block fold_in chunks concatenated. + n_blocks = n_steps // block_size + rem = n_steps - n_blocks * block_size + chunks = [ + jax.random.normal( + jax.random.fold_in(key, i), (block_size, n_noise, n_nodes) + ) + for i in range(n_blocks) + ] + if rem: + chunks.append( + jax.random.normal( + jax.random.fold_in(key, n_blocks), (rem, n_noise, n_nodes) + ) + ) + full_noise = jnp.concatenate(chunks, axis=0) + solve_fn, cfg = prepare(net, Heun(), t0=0.0, t1=self.T1, dt=self.DT) + cfg._internal.noise_samples = full_noise + ref = solve_fn(cfg).ys + self.assertTrue( + jnp.allclose(strm, ref, atol=1e-5), + f"bs={block_size}: max diff {jnp.max(jnp.abs(strm - ref))}", + ) + + def test_forward_invariant_to_grad_horizon(self): + """For a fixed block_size the streamed forward trajectory is bit-exact + across truncation windows (multiples of block_size) -- the absolute + block grid does not depend on how windows tile it.""" + net = self._build() + kw = dict(t0=0.0, t1=self.T1, dt=self.DT) + base = solve(net, Heun(block_size=50), **kw).ys + for window in (100, 150, 300): + with self.subTest(window=window): + ys = solve(net, Heun(block_size=50, grad_horizon=window), **kw).ys + self.assertTrue(jnp.array_equal(base, ys)) + + def test_non_multiple_window_snaps_with_warning(self): + """A grad_horizon that is not a multiple of block_size is snapped to the + nearest multiple with a warning, rather than silently accepted.""" + net = self._build() + with self.assertWarns(UserWarning): + solve( + net, + Heun(block_size=50, grad_horizon=120), + t0=0.0, + t1=self.T1, + dt=self.DT, + ) + + def test_statistical_sanity(self): + """The streamed increments are standard normal: mean ~ 0, variance ~ 1 + across blocks.""" + from tvboptim.experimental.network_dynamics.solve import _streaming_noise_gen + + gen = _streaming_noise_gen(jax.random.key(0), (2, 8)) + sample = jnp.concatenate([gen(i, 64) for i in range(60)], axis=0) + self.assertLess(abs(float(jnp.mean(sample))), 0.02) + self.assertLess(abs(float(jnp.var(sample)) - 1.0), 0.05) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_observations/test_bold_monitors.py b/tests/test_observations/test_bold_monitors.py index f1d73c2..8d28a1a 100644 --- a/tests/test_observations/test_bold_monitors.py +++ b/tests/test_observations/test_bold_monitors.py @@ -172,5 +172,101 @@ def test_bold_alias_warns(self): self.assertIsInstance(b, HRFBold) +class TestStreamingHrfBold(unittest.TestCase): + """The block-level streaming HRF-BOLD reducer matches the post-hoc HRFBold. + + Streaming requires a SubSampling downsample (uniform, streamable) and a + block_size / n_steps that are multiples of the BOLD period in raw steps. + Because a blocked SDE run streams (reseeds) its noise, the equivalence + reference is the post-hoc monitor applied to the SAME streamed trajectory + (matched per-block seeding). + """ + + def _net(self): + import jax + + from tvboptim.experimental.network_dynamics import Network + from tvboptim.experimental.network_dynamics.coupling import ( + DelayedLinearCoupling, + ) + from tvboptim.experimental.network_dynamics.dynamics.tvb import ( + ReducedWongWang, + ) + from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph + from tvboptim.experimental.network_dynamics.noise import AdditiveNoise + + k = jax.random.PRNGKey(7) + wk, dk = jax.random.split(k) + n = 4 + w = jax.random.uniform(wk, (n, n)) * 0.5 + d = jax.random.uniform(dk, (n, n)) * 5.0 + return Network( + dynamics=ReducedWongWang(), + coupling={"delayed": DelayedLinearCoupling(incoming_states="S", G=0.1)}, + graph=DenseDelayGraph(weights=w, delays=d), + noise=AdditiveNoise(sigma=1e-3, key=jax.random.key(0)), + ) + + def test_matches_posthoc_hrfbold(self): + from tvboptim.experimental.network_dynamics import solve + from tvboptim.experimental.network_dynamics.solvers import Heun + from tvboptim.observations.tvb_monitors import ( + HRFBold, + SubSampling, + streaming_hrf_bold, + ) + + net = self._net() + dt = 0.1 + # period/dt = (200/40)*(40/0.1) = 5*400 = 2000 raw steps per block. + mon = HRFBold( + period=200.0, + downsample_period=40.0, + downsample=SubSampling(period=40.0), + ) + t1 = 800.0 # 8000 steps, a multiple of 2000 + # Streaming reducer. + bold = solve( + net, + Heun(block_size=2000), + t0=0.0, + t1=t1, + dt=dt, + reduce=streaming_hrf_bold(mon, dt), + ) + # Post-hoc on the SAME streamed trajectory (matched per-block seeding). + ref = mon(solve(net, Heun(block_size=2000), t0=0.0, t1=t1, dt=dt)) + self.assertEqual(bold.shape, ref.ys.shape) + self.assertTrue( + jnp.allclose(bold, ref.ys, atol=1e-5), + f"max diff {jnp.max(jnp.abs(bold - ref.ys))}", + ) + + def test_misaligned_block_size_rejected(self): + from tvboptim.experimental.network_dynamics import solve + from tvboptim.experimental.network_dynamics.solvers import Heun + from tvboptim.observations.tvb_monitors import ( + HRFBold, + SubSampling, + streaming_hrf_bold, + ) + + net = self._net() + dt = 0.1 + mon = HRFBold( + period=200.0, downsample_period=40.0, downsample=SubSampling(period=40.0) + ) + # block_size=1500 is not a multiple of period/dt=2000 -> assert at trace. + with self.assertRaises(AssertionError): + solve( + net, + Heun(block_size=1500), + t0=0.0, + t1=600.0, + dt=dt, + reduce=streaming_hrf_bold(mon, dt), + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_observations/test_hrf_kernels.py b/tests/test_observations/test_hrf_kernels.py index c5e3394..a247bb0 100644 --- a/tests/test_observations/test_hrf_kernels.py +++ b/tests/test_observations/test_hrf_kernels.py @@ -84,4 +84,4 @@ def test_jittable(kernel): eager = kernel(t, DT) jitted = jax.jit(lambda tt: kernel(tt, DT))(t) peak = jnp.max(jnp.abs(eager)) - assert jnp.allclose(eager, jitted, atol=1e-5 * peak, rtol=1e-4) \ No newline at end of file + assert jnp.allclose(eager, jitted, atol=1e-5 * peak, rtol=1e-4) diff --git a/tests/test_observations/test_observation.py b/tests/test_observations/test_observation.py index c0b3e80..397ddad 100644 --- a/tests/test_observations/test_observation.py +++ b/tests/test_observations/test_observation.py @@ -140,11 +140,13 @@ def test_density_is_non_negative(self): class TestDistributionDistances(unittest.TestCase): def setUp(self): self.x = jnp.linspace(-1.0, 1.0, 200) - self.p = jnp.exp(-(self.x ** 2) / 0.1) + self.p = jnp.exp(-(self.x**2) / 0.1) self.q = jnp.exp(-((self.x - 0.3) ** 2) / 0.1) def test_w1_self_distance_is_zero(self): - self.assertAlmostEqual(float(wasserstein_1d(self.p, self.p, self.x)), 0.0, places=6) + self.assertAlmostEqual( + float(wasserstein_1d(self.p, self.p, self.x)), 0.0, places=6 + ) def test_w1_symmetric(self): self.assertAlmostEqual(