Merged
5 changes: 4 additions & 1 deletion .gitignore
@@ -70,4 +70,7 @@ docs/_styles-quartodoc.css
Thumbs.db

# AI Skills
src/tvboptim/claude/skills
src/tvboptim/claude/skills

# Stray cache dir from running doc notebooks outside their own directory
/cache/
11 changes: 11 additions & 0 deletions docs/_quarto.yml
@@ -25,6 +25,7 @@ website:
- text: "Workflows"
menu:
- file: workflows/RWW.qmd
- file: workflows/TBPTT.qmd
- file: workflows/JR.qmd
- file: workflows/EI_Tuning.qmd
- file: workflows/Hopf_Pareto_ParallelOpt.qmd
@@ -42,6 +43,7 @@ website:
menu:
- file: advanced/buffer_strategies.qmd
- file: advanced/gradient_checkpointing.qmd
- file: advanced/streaming_reductions.qmd
- file: advanced/coupling_freezing.qmd
- file: advanced/subspace_coupling.qmd
- text: "Reference"
@@ -204,3 +206,12 @@ quartodoc:
desc: The unified interface for preparing experiments and networks for simulation
contents:
- tvbo.prepare

- title: Analysis
desc: Characterizing the dynamical regime
package: tvboptim.experimental.network_dynamics.analysis
contents:
- lyapunov
- lyapunov_spectrum
- adiabatic_scan
- AdiabaticScanResult
Binary file not shown.
Binary file not shown.
140 changes: 80 additions & 60 deletions docs/advanced/gradient_checkpointing.qmd
@@ -50,13 +50,13 @@ except ImportError:
The standard remedy is **gradient checkpointing**: instead of saving every
step's activations for the backward pass, save only a sparse subset and
recompute the missing ones on demand. TVB-Optim implements this for the native
solver path as a single optional knob on `NativeSolver`:
solver path as a single optional knob on `NativeSolver`, `block_size`:

```python
solver = Heun(checkpoint_every=256)
solver = Heun(block_size=256)
```

When `checkpoint_every` is `None` (the default) the integration runs as a
When `block_size` is `None` (the default) the integration runs as a
single `jax.lax.scan` exactly as before — there is no overhead and no behaviour
change. When set to an integer `K`, the scan is split into an outer scan over
blocks of `K` steps wrapped in `jax.checkpoint`, with an inner scan running
@@ -66,17 +66,34 @@ overhead — typically 1.3–1.7× depending on workload**, since backward is
usually already several times more expensive than forward and one extra
forward pass adds only a fraction to that total. Forward time is unchanged. The optimum for memory minimisation lies near `K ≈ √n_steps`.
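
The outer-scan / inner-scan split described above can be sketched in plain JAX. This is a minimal illustration, not TVB-Optim's actual implementation: `step`, `scan_plain`, `scan_blocked`, and `loss` are hypothetical names, the dynamics are a toy update, and the sketch assumes `block_size` divides the number of steps (no tail scan).

```python
import jax
import jax.numpy as jnp


def step(state, x):
    # Toy leaky update standing in for one solver step (hypothetical dynamics).
    new_state = state + 0.01 * (-state + jnp.tanh(x))
    return new_state, new_state


def scan_plain(state0, xs):
    # Single monolithic scan: residuals for every step are kept for backward.
    return jax.lax.scan(step, state0, xs)


def scan_blocked(state0, xs, block_size):
    # Outer scan over blocks; each block is an inner scan wrapped in
    # jax.checkpoint, so only block-boundary states are saved and the
    # inner activations are recomputed during the backward pass.
    n_steps = xs.shape[0]
    if n_steps % block_size != 0:
        raise ValueError("this sketch omits the tail scan")
    blocks = xs.reshape(n_steps // block_size, block_size, *xs.shape[1:])

    @jax.checkpoint
    def run_block(state, block_xs):
        return jax.lax.scan(step, state, block_xs)

    final, ys = jax.lax.scan(run_block, state0, blocks)
    return final, ys.reshape(n_steps, *ys.shape[2:])


def loss(g, xs, block_size=None):
    # g scales the input before the scan, so the gradient flows through xs.
    state0 = jnp.zeros(xs.shape[1:])
    if block_size is None:
        _, ys = scan_plain(state0, g * xs)
    else:
        _, ys = scan_blocked(state0, g * xs, block_size)
    return jnp.mean(ys ** 2)


xs = jax.random.normal(jax.random.PRNGKey(0), (1024, 4))
grad_plain = jax.grad(loss)(0.5, xs)
grad_blocked = jax.grad(loss)(0.5, xs, 32)
```

Both paths run the same arithmetic, so the two gradients should agree to floating-point tolerance; only the amount of saved state differs.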

`block_size` is the one block unit for the solver's streaming features (it was
formerly named `checkpoint_every`). Two consequences matter for this notebook:

- On a **stochastic** network, setting `block_size` also switches the noise to
per-block generation, which *reseeds* the realization relative to the
monolithic single draw. To keep this a clean checkpointing benchmark (common
random numbers, bit-exact across `block_size`), we inject one fixed noise
tensor so the block path uses it verbatim instead of streaming. The
streaming-noise memory saving is a separate axis; see
[Streaming Reductions](streaming_reductions.qmd).
- The same `block_size` is also the grain for online `reduce` statistics (e.g.
streamed FC), not covered here.
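
To see why per-block generation changes the realization, here is a small sketch. The per-block key derivation via `jax.random.fold_in` is an assumption made for illustration; the scheme the solver actually uses is not shown in this diff, and `block_noise` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

n_steps, block_size, n_nodes = 1024, 256, 4
key = jax.random.PRNGKey(0)

# Monolithic: one draw covering the whole simulation.
noise_single = jax.random.normal(key, (n_steps, n_nodes))


def block_noise(block_index):
    # ASSUMED scheme: one key per block, derived with fold_in. A streaming
    # solver would then only need one block of noise in memory at a time.
    block_key = jax.random.fold_in(key, block_index)
    return jax.random.normal(block_key, (block_size, n_nodes))


# Concatenated here only to compare against the single draw.
noise_blocked = jnp.concatenate(
    [block_noise(b) for b in range(n_steps // block_size)]
)

same_realization = bool(jnp.allclose(noise_single, noise_blocked))
```

The two tensors have the same shape and distribution but different values, which is why a fixed, injected noise tensor is needed for a like-for-like gradient comparison.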

::: {.callout-note}
## Scope and limitations
- **Native solvers only.** `DiffraxSolver` is not affected; Diffrax exposes
its own `RecursiveCheckpointAdjoint` for adaptive ODE solves, but it does
not support delays.
- **No effect when `checkpoint_every is None`.** The call site falls through to
- **No effect when `block_size is None`.** The call site falls through to
the original `jax.lax.scan(op, state0, scan_inputs)` line. The default
behaviour is bit-exact with prior versions.
- **Forward is unaffected by memory savings.** Forward simulations do not retain
step activations regardless; checkpointing only matters when you take a
gradient.
- **SDE noise is held fixed here.** With a stochastic network, `block_size`
streams (reseeds) the noise; this benchmark injects a fixed noise tensor so
every config integrates the same path and the checkpointed gradient stays
bit-exact to the uncheckpointed one.
:::

```{python}
@@ -183,7 +200,7 @@ top of the dynamics state, the noise tensor, and the auxiliary tape.
## Benchmark

We benchmark forward time, gradient time, and (where the backend supports it)
peak device memory across a sweep of `checkpoint_every` values. The sweep
peak device memory across a sweep of `block_size` values. The sweep
covers:

- `None` — the default, single `jax.lax.scan`. Reference for performance.
@@ -204,11 +221,21 @@ covers:
# is a clean divisor of n_steps (no tail). Most other values do not divide
# n_steps exactly and therefore exercise the main-scan + tail-scan path,
# which matters for the memory story — see "Reading the memory curve".
CHECKPOINT_VALUES = [None, 32, 128, 256, 512, 1024, 2048, 8192, 30000, N_STEPS]
BLOCK_SIZE_VALUES = [None, 32, 128, 256, 512, 1024, 2048, 8192, 30000, N_STEPS]
N_FORWARD_RUNS = 3
N_GRADIENT_RUNS = 3
G_INIT = jnp.asarray(0.5)

# Fixed noise realization (common random numbers). Injecting this into the
# config makes `block_size` do pure gradient checkpointing rather than per-block
# streaming: every config integrates the same noise path, so the checkpointed
# gradient stays bit-exact to the uncheckpointed one and the benchmark isolates
# the activation-tape effect. Shape is [n_steps, n_noise_states, n_nodes].
n_noise_states = len(network.noise._state_indices)
FIXED_NOISE = jax.random.normal(
network.noise.key, (N_STEPS, n_noise_states, n_nodes)
)


class RSSPeakMonitor:
"""Context manager that records peak process RSS during the with-block.
@@ -276,11 +303,14 @@ class RSSPeakMonitor:
self._stop.wait(self.sample_interval)


def benchmark_one(checkpoint_every, fc_target):
def benchmark_one(block_size, fc_target):
"""Time forward + gradient, capture peak RSS during gradient, and return
the gradient value for cross-check."""
solver = Heun(checkpoint_every=checkpoint_every)
solver = Heun(block_size=block_size)
solve_fn, state = prepare(network, solver, t0=0.0, t1=T1, dt=DT)
# Inject the fixed noise so block_size does pure checkpointing (no per-block
# streaming / reseed); all configs then share the same realization.
state._internal.noise_samples = FIXED_NOISE
solve_fn = jax.jit(solve_fn)

def loss(G):
@@ -338,12 +368,12 @@ def benchmark_one(checkpoint_every, fc_target):
}


@cache("checkpoint_sweep", redo=False)
@cache("block_size_sweep", redo=True)
def run_sweep():
results = {}
for k in CHECKPOINT_VALUES:
for k in BLOCK_SIZE_VALUES:
label = "None" if k is None else str(k)
print(f"checkpoint_every = {label} ...", flush=True)
print(f"block_size = {label} ...", flush=True)
results[label] = benchmark_one(k, fc_target)
gc.collect()
return results
@@ -356,15 +386,15 @@ sweep_results = run_sweep()

```{python}
#| label: fig-checkpoint-benchmark
#| fig-cap: "**Gradient checkpointing benchmark.** Top row — *time*. Top-left: Forward and gradient wall time as a function of `checkpoint_every`; dotted horizontals mark the `None` baseline, dashed vertical marks `√n_steps`. Top-right: Per-call gradient-to-forward ratio. Bottom row — *memory*. Bottom-left: Peak RSS delta during a gradient call vs `checkpoint_every`, showing the `O(n_steps/K + K)` minimum near `√n_steps`. Bottom-right: Memory–time Pareto, with the `None` star at the low-time / high-memory extreme and checkpointed points tracing the front."
#| fig-cap: "**Gradient checkpointing benchmark.** Top row — *time*. Top-left: Forward and gradient wall time vs `block_size` on a shared log y-axis (the two sit about a decade apart); dashed horizontals mark each curve's `None` baseline, the dashed vertical marks `√n_steps`. Top-right: Per-call gradient-to-forward ratio. Bottom row — *memory*. Bottom-left: Peak RSS delta during a gradient call vs `block_size` (linear y), showing the `O(n_steps/K + K)` minimum near `√n_steps`. Bottom-right: Memory–time Pareto (front and `None` points labelled), with the `None` star at the low-time / high-memory extreme and checkpointed points tracing the front."
#| echo: true
#| code-fold: true
#| code-summary: "Plotting code"

baseline = sweep_results["None"]
sqrt_n = np.sqrt(N_STEPS)

# K-axis panels drop "None" — it has no x-coordinate on a checkpoint_every
# K-axis panels drop "None" — it has no x-coordinate on a block_size
# axis, only a horizontal-reference role. The Pareto panel keeps it as a
# distinct star marker because its axes are (time, memory) and there is no
# overlap risk.
@@ -430,47 +460,31 @@ with plt.rc_context({

# === Top row: time ===

# --- Top-left: time vs checkpoint_every (twin y-axes) ---
# Forward and gradient times are an order of magnitude apart, so a
# shared y-axis would collapse the forward curve to a flat line near
# the bottom. Twin axes let each curve fill its own range; we use
# linear scaling on both because each axis spans well under a decade.
# Axis spines and tick colours are tinted to indicate which curve goes
# with which side.
# --- Top-left: time vs block_size (single log y-axis) ---
# Forward (~0.1 s) and gradient (~1 s) are about a decade apart, so a single
# log y-axis separates them cleanly and keeps each None baseline next to its
# own curve without the two dashed references overlapping (which a linear
# twin-axis layout did). The fine overhead detail lives in the panel to the
# right (grad / forward ratio).
ax = axes[0, 0]
ax_g = ax.twinx()

fwd_color = "steelblue"
grad_color = "firebrick"

ax.errorbar(xs, fwd, yerr=fwd_err, marker="o", color=fwd_color,
label="forward", lw=1.8, markersize=7, capsize=3)
ax.axhline(baseline["fwd_mean"], color=fwd_color, linestyle="dashed",
alpha=0.7, label="forward (None)")

ax_g.errorbar(xs, grad, yerr=grad_err, marker="s", color=grad_color,
label="gradient", lw=1.8, markersize=7, capsize=3)
ax_g.axhline(baseline["grad_mean"], color=grad_color, linestyle="dashed",
alpha=0.7, label="gradient (None)")
ax.errorbar(xs, grad, yerr=grad_err, marker="s", color=grad_color,
label="gradient", lw=1.8, markersize=7, capsize=3)
ax.axhline(baseline["grad_mean"], color=grad_color, linestyle="dashed",
alpha=0.7, label="gradient (None)")

ax.set_xscale("log")
ax.set_xlabel("checkpoint_every")
ax.set_ylabel("forward wall time (s)", color=fwd_color)
ax_g.set_ylabel("gradient wall time (s)", color=grad_color)
ax.set_title("Time vs checkpoint_every")

# Tint the axis ticks and spines to match the data they describe.
ax.tick_params(axis="y", colors=fwd_color)
ax.spines["left"].set_color(fwd_color)
ax_g.tick_params(axis="y", colors=grad_color)
ax_g.spines["right"].set_color(grad_color)
ax_g.spines["left"].set_visible(False)

# Combined legend from both axes.
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = ax_g.get_legend_handles_labels()
ax.legend(h1 + h2, l1 + l2, loc="best", framealpha=0.9)

ax.set_yscale("log")
ax.set_xlabel("block_size")
ax.set_ylabel("wall time (s)")
ax.set_title("Time vs block_size")
ax.legend(loc="best", framealpha=0.9, ncol=2)
ax.grid(alpha=0.3, which="both")
_mark_sqrt_n(ax)

@@ -483,7 +497,7 @@ with plt.rc_context({
ax.axhline(baseline_ratio, color="darkgreen", linestyle="dashed", alpha=0.7,
label=f"None baseline ({baseline_ratio:.2f}×)")
ax.set_xscale("log")
ax.set_xlabel("checkpoint_every")
ax.set_xlabel("block_size")
ax.set_ylabel("grad / forward")
ax.set_title("Gradient overhead")
ax.grid(alpha=0.3, which="both")
@@ -492,18 +506,19 @@ with plt.rc_context({

# === Bottom row: memory ===

# --- Bottom-left: memory vs checkpoint_every ---
# --- Bottom-left: memory vs block_size ---
ax = axes[1, 0]
if has_memory:
ax.plot(xs, mem_ck_mb, marker="D", color="purple", lw=1.8,
markersize=8, label="peak RSS delta during grad")
ax.axhline(baseline["peak_bytes_delta"] / 1e6, color="purple",
linestyle="dashed", alpha=0.7, label="None baseline")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("checkpoint_every")
ax.set_xscale("log") # block_size spans decades; y stays linear so the
# U-shape and the absolute MB differences read directly.
ax.set_ylim(bottom=0)
ax.set_xlabel("block_size")
ax.set_ylabel("peak RSS delta during grad (MB)")
ax.set_title("Memory vs checkpoint_every")
ax.set_title("Memory vs block_size")
ax.grid(alpha=0.3, which="both")
ax.legend(loc="best", framealpha=0.9)
_mark_sqrt_n(ax)
@@ -515,7 +530,7 @@ with plt.rc_context({
bbox=dict(boxstyle="round,pad=0.5", facecolor="lightyellow"))
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Memory vs checkpoint_every (unavailable)")
ax.set_title("Memory vs block_size (unavailable)")

# --- Bottom-right: memory–time Pareto ---
# Two cleanup ideas vs the old "connect-by-time" line, which crossed
@@ -557,8 +572,13 @@ with plt.rc_context({
else:
ax.scatter([x], [y], s=55, facecolor="white",
edgecolor="purple", linewidth=1.3, zorder=2)
ax.annotate(l, (x, y), textcoords="offset points",
xytext=(8, 6), fontsize=11)
# Label only the front points and None: the dominated points cluster
# near the front and their labels collide. Alternate the vertical
# offset to further reduce overlap among the labelled ones.
if on_front or l == "None":
dy = 8 if (i % 2 == 0) else -12
ax.annotate(l, (x, y), textcoords="offset points",
xytext=(8, dy), fontsize=10)

ax.set_xlabel("gradient time (s)")
ax.set_ylabel("peak RSS delta during grad (MB)")
@@ -599,7 +619,7 @@ with plt.rc_context({

## Reading the Memory Curve

The memory-vs-`checkpoint_every` panel (bottom-left) shows the classical analysis. Peak gradient
The memory-vs-`block_size` panel (bottom-left) shows the classical analysis. Peak gradient
memory scales as

$$ \mathrm{peak\,memory} \;\approx\; \underbrace{\frac{n_\text{steps}}{K} \cdot c_\text{outer}}_{\text{block-boundary tape}} \;+\; \underbrace{K \cdot c_\text{inner}}_{\text{per-block inner tape during backward}} $$
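
A short worked minimisation of this model, treating $K$ as continuous and $c_\text{outer}$, $c_\text{inner}$ as constants (an idealisation; the tail scan and allocator behaviour can perturb the measured curve):

$$ \frac{d}{dK}\left(\frac{n_\text{steps}}{K}\, c_\text{outer} + K\, c_\text{inner}\right) \;=\; -\frac{n_\text{steps}\, c_\text{outer}}{K^2} + c_\text{inner} \;=\; 0 \quad\Longrightarrow\quad K^\star = \sqrt{n_\text{steps}\,\frac{c_\text{outer}}{c_\text{inner}}} $$

$$ \mathrm{peak\,memory}(K^\star) \;\approx\; 2\sqrt{n_\text{steps}\, c_\text{outer}\, c_\text{inner}} $$

When $c_\text{outer} \approx c_\text{inner}$ this reduces to $K^\star \approx \sqrt{n_\text{steps}}$, and both terms contribute equally at the optimum.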
@@ -673,7 +693,7 @@ from the saved tape.
#| echo: true

baseline = sweep_results["None"]
print(f"{'checkpoint_every':<20} {'loss':<22} {'grad':<22} {'|Δgrad / grad|':<18}")
print(f"{'block_size':<20} {'loss':<22} {'grad':<22} {'|Δgrad / grad|':<18}")
print("-" * 82)
for label, res in sweep_results.items():
rel = abs((res["grad_value"] - baseline["grad_value"]) / baseline["grad_value"])
@@ -700,7 +720,7 @@ or the device-memory delta if `jax.devices()[0].memory_stats()` is available

baseline = sweep_results["None"]
header = (
f"{'ckpt_every':<12} "
f"{'block_size':<12} "
f"{'fwd_s':<14} "
f"{'grad_s':<14} "
f"{'grad/fwd':<10} "
@@ -750,15 +770,15 @@ print(f"# device: {jax.devices()[0].platform} jax {jax.__version__}")

## No-Regression Check

Because `checkpoint_every=None` selects the original `jax.lax.scan` call site
Because `block_size=None` selects the original `jax.lax.scan` call site
verbatim (a literal if-branch), forward and gradient times for the default
must be **within timing noise** of the previous non-checkpointed implementation.
The benchmark above implicitly verifies this: the `None` row should be
statistically indistinguishable from any prior measurement of the uncheckpointed
path. `K = n_steps` is **not** equivalent to `None` — it still
wraps the (single) inner scan in `jax.checkpoint`, so the backward pass
recomputes the entire forward once, costing roughly 1.3× the `None`
gradient time. Only `checkpoint_every=None` skips checkpointing entirely.
gradient time. Only `block_size=None` skips checkpointing entirely.

## Practical Guidance

@@ -770,11 +790,11 @@ from tvboptim.experimental.network_dynamics.solvers import Heun
solver = Heun()

# Memory-optimal default when gradients no longer fit in memory.
solver = Heun(checkpoint_every=int(math.sqrt(n_steps)))
solver = Heun(block_size=int(math.sqrt(n_steps)))

# Aggressive: minimal memory, maximal recompute. Use only if the sqrt
# default still OOMs.
solver = Heun(checkpoint_every=64)
solver = Heun(block_size=64)
```
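
The benchmark comments note that block sizes which do not divide `n_steps` exercise a main-scan plus tail-scan path. If you would rather avoid the tail, one option is to pick the divisor of `n_steps` closest to `√n_steps`. The helper below is a hypothetical convenience, not part of TVB-Optim, and whether avoiding the tail is worthwhile depends on the workload.

```python
import math


def divisor_near_sqrt(n_steps: int) -> int:
    """Return the divisor of n_steps closest to sqrt(n_steps)."""
    if n_steps < 1:
        raise ValueError("n_steps must be a positive integer")
    target = math.sqrt(n_steps)
    best = 1
    # Divisors come in pairs (d, n_steps // d) with d <= sqrt(n_steps).
    for d in range(1, math.isqrt(n_steps) + 1):
        if n_steps % d == 0:
            for candidate in (d, n_steps // d):
                if abs(candidate - target) < abs(best - target):
                    best = candidate
    return best
```

The result can be passed straight to the solver, for example `Heun(block_size=divisor_near_sqrt(n_steps))`. For a prime `n_steps` the helper returns 1, which is rarely what you want; in that case fall back to `int(math.sqrt(n_steps))` and accept the tail.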

The same field works on `Euler`, `Heun`, `RungeKutta4`, and any