binary search for batch size #362
Refine auto batch size with binary search after OOM. After the exponential probe hits CUDA OOM, binary-search between the last successful and first failed size to find higher-throughput batch sizes, still picking the size with the best measured tokens/s.
Code Review
This pull request refactors the batch size determination logic into helper functions and introduces a binary search refinement mechanism to find the optimal batch size when an out-of-memory (OOM) error is encountered. Additionally, the default max_batch_size is increased from 128 to 1024. Feedback suggests using torch.OutOfMemoryError instead of torch.cuda.OutOfMemoryError to make OOM detection more robust across different accelerators, and adding defensive guards in _determine_batch_size to handle edge cases such as empty prompts or invalid maximum batch sizes.
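The defensive guards mentioned in the review could look roughly like the sketch below. The helper name `validate_batch_search_inputs` and its error messages are hypothetical and are not taken from the PR; the sketch only shows the idea of rejecting empty prompts and invalid maximum batch sizes before any probing starts.

```python
from typing import Sequence


def validate_batch_search_inputs(prompts: Sequence[str], max_batch_size: int) -> None:
    """Hypothetical guard, run before probing any batch size.

    Raises ValueError for the two edge cases named in the review:
    an empty prompt list and a non-positive maximum batch size.
    """
    if len(prompts) == 0:
        raise ValueError("cannot determine a batch size without prompts")
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
```

Failing early here keeps the probing loop itself free of special cases; whether the real helper raises or falls back to a default is a design choice this sketch does not settle.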
Hey Noah, verified this on my GPU and it works great! I added the fixes for the code review comments (generalizing to OutOfMemoryError and adding defensive guards in _determine_batch_size). Feel free to pull them from my branch:
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
The current approach is not ideal, but a binary search is not the right solution either IMO, at least not in this form. Here's the crux of the problem: We aren't actually interested in finding the largest possible batch size. There is another constraint, and that is the size of the prompt dataset(s). By default, the evaluation datasets contain 100 prompts each. That means choosing any batch size greater than 100 doesn't make sense, because we'll never run more than 100 prompts anyway. In fact, the only batch sizes that we should consider are The correct search strategy is to start with a batch size of
When batch_size = 0 (auto), Heretic probes batch sizes exponentially and picks the size with the best measured throughput (tokens/s). Previously, if a probe hit CUDA OOM, the search stopped and kept the last successful power of two, even when a larger non-power-of-two batch would fit and run faster. I also saw #248, but I think this is a good solution for now.
This PR refines the search after OOM: it binary-searches between the last successful size and the first failing size, still choosing the batch size with the highest measured tokens/s.
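A minimal sketch of the two-phase search described above, not the PR's actual code. The function name `find_best_batch_size` and the `measure` callback are assumptions made for illustration: `measure(size)` is assumed to run one probe and return tokens/s, and it raises the built-in `MemoryError` as a stand-in for the accelerator OOM exception.

```python
from typing import Callable, Optional, Tuple


def find_best_batch_size(
    measure: Callable[[int], float],
    max_batch_size: int = 1024,
) -> Tuple[Optional[int], float]:
    """Return (batch_size, tokens_per_second) with the best measured throughput.

    `measure` is a hypothetical callback: it returns tokens/s for one probe
    and raises MemoryError (stand-in for OOM) when the batch does not fit.
    """
    best_size: Optional[int] = None
    best_rate = 0.0

    def probe(size: int) -> bool:
        """Run one probe, record it if it is the fastest so far, report fit."""
        nonlocal best_size, best_rate
        try:
            rate = measure(size)
        except MemoryError:
            return False
        if rate > best_rate:
            best_size, best_rate = size, rate
        return True

    # Phase 1: exponential probe (1, 2, 4, ...) until OOM or the cap.
    last_ok, first_fail = 0, None
    size = 1
    while size <= max_batch_size:
        if not probe(size):
            first_fail = size
            break
        last_ok = size
        size *= 2

    # Phase 2: after an OOM, binary-search the open interval
    # (last_ok, first_fail) for larger sizes that still fit.
    if first_fail is not None:
        lo, hi = last_ok, first_fail
        while hi - lo > 1:
            mid = (lo + hi) // 2
            if probe(mid):
                lo = mid
            else:
                hi = mid

    return best_size, best_rate
```

Every size that fits is scored, so the result is the fastest probed size rather than simply the largest one that fits. If even a batch size of 1 fails, the sketch returns `(None, 0.0)`.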