
Overlap data generation with GPU training via async pipeline #11

Open
j-vaught wants to merge 1 commit into forestagostinelli:main from j-vaught:async-pipeline

Conversation

@j-vaught j-vaught commented Apr 1, 2026

Hey Dr. Agostinelli,

Just proposing a small change to speed up the library a tad.

Summary

  • Overlap next-round data generation with current-round GPU training using a background threading.Thread and a second DataBuffer
  • First iteration is synchronous (cold start); all subsequent iterations prefetch data concurrently with training
  • No changes to CLI, APIs, sync_main path, or training algorithm

More info

In the default (non-sync_main) training path, update_step() blocks on _get_update_data() until all worker data arrives, then trains sequentially. The GPU is idle during data generation (~7s) and the CPU is idle during training (~2.5s). This change overlaps the two phases by starting the next round's data collection in a background thread before calling _train().
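A back-of-the-envelope check of the quoted phase timings (a sketch using the ~7s generation / ~2.5s training figures above; the measured 1.61x also reflects other per-update overheads, so this is only a lower-bound estimate):

```python
# Rough per-phase timings quoted above (seconds); illustrative only.
gen_s, train_s = 7.0, 2.5

serial = gen_s + train_s          # phases run back-to-back: 9.5s per update
overlapped = max(gen_s, train_s)  # steady state with prefetch: 7.0s per update

print(f"{serial:.1f}s -> {overlapped:.1f}s ({serial / overlapped:.2f}x)")
# prints: 9.5s -> 7.0s (1.36x)
```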

Benchmark

Tested on RTX 6000 Ada Generation (48 GB), Cube3 domain, resnet_fc.5000H_4B_bn, 5000 iterations, 3 runs per configuration.

         Baseline   Async    Speedup
Run 1    10m 03s    6m 15s   1.61x
Run 2    10m 08s    6m 16s   1.62x
Run 3    10m 05s    6m 19s   1.60x
Mean     10m 05s    6m 17s   1.61x

Per-update time drops from ~11.4s to ~6.8s in steady state. Training convergence (loss, solve rate) is unaffected -- differences are within run-to-run variance.

How it works

  1. After _end_update(N) completes, call start_update(N+1) and launch a prefetch thread
  2. The prefetch thread calls updater.get_update_data() (which blocks on from_q.get(), releasing the GIL)
  3. Meanwhile, the main thread runs _train() (PyTorch CUDA ops, also GIL-free)
  4. On the next update_step() call, wait for the prefetch thread, swap buffers, and repeat

The end_update(N) / start_update(N+1) ordering constraint is preserved and both run on the main thread before training begins.
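The steps above can be sketched as a minimal double-buffered prefetch loop. This is a toy model, not the repo's actual API: `worker`, `get_update_data`, and `train` are hypothetical stand-ins for the real worker processes, `updater.get_update_data()`, and `_train()`.

```python
import threading
import queue

def worker(from_q, n_rounds):
    # Stand-in for the data-generation workers: one "batch" per round.
    for r in range(n_rounds):
        from_q.put([r] * 4)

def get_update_data(from_q):
    # Stand-in for updater.get_update_data(): blocks on from_q.get(),
    # which releases the GIL while waiting.
    return from_q.get()

def train(batch):
    # Stand-in for _train(); here just a cheap reduction.
    return sum(batch)

def run_pipeline(n_rounds):
    from_q = queue.Queue()
    threading.Thread(target=worker, args=(from_q, n_rounds), daemon=True).start()

    results = []
    # Cold start: the first round's data is fetched synchronously.
    buf = get_update_data(from_q)
    for r in range(n_rounds):
        next_buf = {}
        if r + 1 < n_rounds:
            # Prefetch the next round's data while we train on the current one.
            t = threading.Thread(
                target=lambda: next_buf.setdefault("data", get_update_data(from_q)))
            t.start()
        results.append(train(buf))
        if r + 1 < n_rounds:
            t.join()                 # wait for the prefetch thread
            buf = next_buf["data"]   # swap buffers for the next iteration
    return results

print(run_pipeline(3))  # -> [0, 4, 8]
```

In the real change the prefetch thread and the main thread each own one of two DataBuffer instances, so neither side mutates data the other is reading; the swap at the top of update_step() is the only synchronization point.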

Add a background threading.Thread that collects next-round data into
a second DataBuffer while the main thread trains on the current round.
First iteration is synchronous (cold start); subsequent iterations
prefetch concurrently. sync_main path is unchanged.

Benchmarked on RTX 6000 Ada (Cube3, 5000 iterations, 3 runs):
  Baseline mean: 10m 05s
  Async mean:     6m 17s  (1.61x speedup)

Training convergence unaffected. All code paths regression tested.

j-vaught commented Apr 1, 2026

Looks like this imports a new library; I need to add it to the right places.


j-vaught commented Apr 1, 2026

Nevermind. It's part of the Python standard library.

@j-vaught j-vaught marked this pull request as ready for review April 1, 2026 00:25
