feat: Implement KV-cache for umT5 (#137) #162
Satyamgupta2365 wants to merge 1 commit into jax-ml:main
Conversation
Summary of Changes
Hello @Satyamgupta2365, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Highlights
Hi @jenriver and @chapman20j, thank you for the detailed feedback. I'll remove the extra artifacts, refactor the cache to align with existing Bonsai patterns, ensure pretrained weight loading works correctly, and update the tests with proper numerical comparisons. This is the new PR with the updated code.
Code Review
This pull request introduces a KV-cache for the umT5 model to accelerate autoregressive inference, which is a significant performance enhancement. The implementation is generally well-structured, including the new LayerCache class and the necessary modifications to the attention and generation logic. However, I've identified a critical issue in the self-attention caching mechanism that could lead to incorrect outputs and inefficiency. My review includes a specific code suggestion to address this.
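The LayerCache class itself isn't shown in this thread. As a hedged illustration only, a minimal per-layer cache along these lines would hold preallocated key/value buffers plus a write cursor; all names and shapes below are illustrative, not the PR's actual Bonsai code.

```python
import jax
import jax.numpy as jnp


class LayerCache:
    """Sketch of a per-layer KV-cache: preallocated buffers plus a cursor.

    Shapes follow the (batch, seq, heads, head_dim) layout implied by the
    review discussion; names here are hypothetical, not Bonsai's API.
    """

    def __init__(self, batch, max_len, num_heads, head_dim, dtype=jnp.float32):
        self.k_cache = jnp.zeros((batch, max_len, num_heads, head_dim), dtype)
        self.v_cache = jnp.zeros((batch, max_len, num_heads, head_dim), dtype)
        self.cur_ind = 0  # next write position along the sequence axis

    def update(self, k_new, v_new):
        """Write new key/value chunks at the cursor, advance it, and
        return only the valid prefix of the cache."""
        q_len = k_new.shape[1]
        self.k_cache = jax.lax.dynamic_update_slice(
            self.k_cache, k_new, (0, self.cur_ind, 0, 0))
        self.v_cache = jax.lax.dynamic_update_slice(
            self.v_cache, v_new, (0, self.cur_ind, 0, 0))
        self.cur_ind += q_len
        return self.k_cache[:, : self.cur_ind], self.v_cache[:, : self.cur_ind]
```

Returning only the valid prefix is what the review comment below asks for: attention then never sees the zero-padded tail of the buffers.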
```python
k = cache.k_cache[...]
v = cache.v_cache[...]
cache.cur_ind.value = cache.cur_ind.value + q_len
```
The current implementation retrieves the full cache tensors for keys and values but doesn't mask or slice away the unused padded portions. Because the softmax is then computed over zero vectors, the attention scores can be incorrect, and computation is wasted on padding. To fix this, slice the cache to its valid length after updating it.
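A tiny numeric check (simplified to a single query and head) illustrates the problem: zero-vector padding keys receive logit 0 rather than -inf, so the softmax assigns them nonzero weight and dilutes the weights over the real tokens.

```python
import jax
import jax.numpy as jnp

# One query attending over a cache of max_len 4 where only 2 slots are valid.
q = jnp.array([1.0, 0.0])
k_full = jnp.array([[1.0, 0.0],   # valid token
                    [0.0, 1.0],   # valid token
                    [0.0, 0.0],   # padding (zero vector)
                    [0.0, 0.0]])  # padding (zero vector)

logits_full = k_full @ q             # padded slots get logit 0, not -inf
w_full = jax.nn.softmax(logits_full)

# Attention restricted to the valid prefix, as the fix suggests.
w_valid = jax.nn.softmax(k_full[:2] @ q)
# Padding absorbs probability mass: w_full[:2].sum() < 1.0, so the
# weights over the two real tokens differ from w_valid.
```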
Suggested change:

```python
new_len = cache.cur_ind.value + q_len
k = cache.k_cache[...][:, :new_len]
v = cache.v_cache[...][:, :new_len]
cache.cur_ind.value = new_len
```
@Satyamgupta2365 Make sure to use
Summary
Implements KV-cache for umT5 model to enable efficient autoregressive inference.
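To illustrate the intended effect (a sketch with made-up shapes, not the PR's actual code): during decoding, each step appends only the new token's key/value to the cache and attends over the valid prefix, so per-step attention cost grows linearly with sequence length instead of quadratically.

```python
import jax
import jax.numpy as jnp


def attend(q, k, v):
    """Plain scaled dot-product attention over the valid cache prefix."""
    scores = jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(float(q.shape[-1]))
    return jax.nn.softmax(scores, axis=-1) @ v


# Without a cache, step t would recompute K/V for all t tokens.
# With a cache, step t writes only the new token's K/V and reuses the rest.
k_cache = jnp.zeros((8, 4))   # (max_len, head_dim), illustrative shapes
v_cache = jnp.zeros((8, 4))
cur = 0
for step in range(3):
    x = jnp.ones((1, 4)) * step          # stand-in for the new token's K/V
    k_cache = k_cache.at[cur].set(x[0])  # append at the cursor
    v_cache = v_cache.at[cur].set(x[0])
    cur += 1
    out = attend(x, k_cache[:cur], v_cache[:cur])  # attend over valid prefix
```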
Changes
`use_cache` parameter (default: True)

Performance
Files Modified
Closes #137