Skip to content

binary search for batch size#362

Open
NoahOksuz wants to merge 4 commits into
p-e-w:masterfrom
NoahOksuz:optibatchsize
Open

binary search for batch size#362
NoahOksuz wants to merge 4 commits into
p-e-w:masterfrom
NoahOksuz:optibatchsize

Conversation

@NoahOksuz

Copy link
Copy Markdown

When batch_size = 0 (auto), Heretic probes batch sizes exponentially and picks the size with the best measured throughput (tokens/s). Previously, if a probe hit CUDA OOM, the search stopped and kept the last successful power-of-two even though when a larger non-power-of-two batch would fit and run faster. I also saw #248 but i think this is a good solution for now.

Refines the search after OOM binary-searches between the last successful size and the first failing size, still choosing the batch with the highest measured tokens/s.

Determining optimal batch size...
* Trying batch size 1... Ok (20 tokens/s)
* Trying batch size 2... Ok (38 tokens/s)
* blah blah
* Trying batch size 128... Ok (814 tokens/s)
* Trying batch size 256... Failed (CUDA out of memory...)
* Trying batch size 192... Ok (902 tokens/s)
* Trying batch size 224... Ok (905 tokens/s)
* Trying batch size 240... Ok (914 tokens/s)
* Trying batch size 248... Failed (CUDA out of memory...)
* Chosen batch size: 240

NoahOksuz added 2 commits June 7, 2026 13:38
Refine auto batch size with binary search after OOM.

After the exponential probe hits CUDA OOM, binary-search between the last
successful and first failed size to find higher-throughput batch sizes,
still picking the size with the best measured tokens/s.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the batch size determination logic into helper functions and introduces a binary search refinement mechanism to find the optimal batch size when an out-of-memory (OOM) error is encountered. Additionally, the default max_batch_size is increased from 128 to 1024. Feedback suggests using torch.OutOfMemoryError instead of torch.cuda.OutOfMemoryError to make OOM detection more robust across different accelerators, and adding defensive guards in _determine_batch_size to handle edge cases such as empty prompts or invalid maximum batch sizes.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread src/heretic/main.py
Comment thread src/heretic/main.py
@umran666

umran666 commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

Hey Noah, verified this on my GPU and it works great! I added the fixes for the code review comments (generalizing to OutOfMemoryError and adding defensive guards in _determine_batch_size). Feel free to pull them from my branch: https://github.com/umran666/heretic/tree/fix/binary-search-batch-size

NoahOksuz and others added 2 commits June 7, 2026 16:53
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@p-e-w

p-e-w commented Jun 8, 2026

Copy link
Copy Markdown
Owner

The current approach is not ideal, but a binary search is not the right solution either IMO, at least not in this form. Here's the crux of the problem:

We aren't actually interested in finding the largest possible batch size. There is another constraint, and that is the size of the prompt dataset(s).

By default, the evaluation datasets contain 100 prompts each. That means choosing any batch size greater than 100 doesn't make sense, because we'll never run more than 100 prompts anyway.

In fact, the only batch sizes that we should consider are ceil(len(prompts) / n) for n = 1, 2, 3, .... Choosing any other size will always be dominated by the next smaller member of this set, because you need the same number of batches but are less robust against VRAM fluctuations.

The correct search strategy is to start with a batch size of len(prompts), then proceed with len(prompts) / 2 if there is an OOM, followed by len(prompts) / 3 etc. But as noted in #248, this needs to happen on each call to generate (with caching), because otherwise, the locked-in batch size from the start of the run leads to a crash if some other process consumes VRAM in between.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants