Code Review
This pull request introduces KV caching utilities, including a standard LayerCache and a CyclicCache, along with a protocol they adhere to. While the initiative to create consistent caching utilities is great, this PR has some critical issues that need to be addressed before merging.
Most importantly, the associated test file bonsai/utils/tests/kv_cache_test.py contains only empty test stubs, despite the PR checklist indicating that tests have been added. The caching logic, especially for CyclicCache, is complex and includes author notes about needing verification. It is crucial to add comprehensive unit tests.
Additionally, there are several issues in the implementation:
- The `CyclicCache.update` method has a bug that can lead to out-of-bounds writes when an update wraps around the buffer.
- The `prefill` methods in both cache classes don't correctly handle `segment_ids` for initialization, which will lead to runtime errors.
- The `CacheProtocol` has a method signature mismatch with its implementations.
- The module configures logging at import time, which is a bad practice for a library.
Please address these points, with the highest priority on adding tests.
```python
def update(self, k: Array, v: Array):
    assert self.start_ind_initialized, "Must initialize start_ind before updating LayerCache"
    slice_indices = (0, self.cur_ind[...] % self.cache_size, 0, 0)
    self.k_cache[...] = jax.lax.dynamic_update_slice(self.k_cache[...], k, slice_indices)
    self.v_cache[...] = jax.lax.dynamic_update_slice(self.v_cache[...], v, slice_indices)
    self.cur_ind[...] = self.cur_ind[...] + k.shape[1]
```
The `update` method in `CyclicCache` mishandles updates that wrap around the cache buffer when `k.shape[1] > 1`. Note that `jax.lax.dynamic_update_slice` does not actually write past the end of the array: it clamps the start index so the update window fits, which silently writes the new tokens to the wrong slots. For example, if `cache_size=4`, `cur_ind=3`, and `k.shape[1]=2`, the start index is clamped from 3 to 2, so the update overwrites slots 2 and 3 instead of wrapping the second token to slot 0.

While the TODO notes an assumption of single-token updates post-prefill, the code doesn't enforce it. To prevent incorrect behavior, you should either handle the wrap-around case (which would require splitting the update into two `dynamic_update_slice` calls) or add an assertion to enforce the assumption that updates do not wrap.
Suggested change (add an assertion enforcing the no-wrap assumption):

```python
def update(self, k: Array, v: Array):
    assert self.start_ind_initialized, "Must initialize start_ind before updating LayerCache"
    assert self.cur_ind[...] % self.cache_size + k.shape[1] <= self.cache_size, (
        "Update logic does not support wrapping around the buffer yet."
    )
    slice_indices = (0, self.cur_ind[...] % self.cache_size, 0, 0)
    self.k_cache[...] = jax.lax.dynamic_update_slice(self.k_cache[...], k, slice_indices)
    self.v_cache[...] = jax.lax.dynamic_update_slice(self.v_cache[...], v, slice_indices)
    self.cur_ind[...] = self.cur_ind[...] + k.shape[1]
```
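If wrap-around support is preferred over an assertion, the split-write approach could look like the sketch below. This is a standalone illustration, not the PR's code: `cyclic_update` is a hypothetical helper, and it assumes `cur_ind` is a concrete Python int (under `jit`, the branch would need `lax.cond` or a mask-based scatter instead).

```python
import jax
import jax.numpy as jnp

def cyclic_update(cache, new, cur_ind, cache_size):
    """Write `new` (shape [B, T, H, D]) into the cyclic `cache` starting at
    cur_ind % cache_size, splitting into two writes when the update would
    run past the end of the buffer."""
    start = cur_ind % cache_size
    t = new.shape[1]
    first_len = min(t, cache_size - start)  # tokens that fit before the end
    cache = jax.lax.dynamic_update_slice(cache, new[:, :first_len], (0, start, 0, 0))
    if first_len < t:
        # Remaining tokens wrap to the front of the buffer.
        cache = jax.lax.dynamic_update_slice(cache, new[:, first_len:], (0, 0, 0, 0))
    return cache, cur_ind + t
```

With `cache_size=4` and `cur_ind=3`, a 2-token update lands in slots 3 and 0, matching the intended cyclic semantics.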
```python
class TestLayerCache(absltest.TestCase):
    def setUp(self):
        super().setUp()

    def test_init(self):
        pass

    def test_prefill(self):
        pass

    def test_update(self):
        pass

    def test_compute_causal_mask(self):
        pass
```
The test cases for `TestLayerCache` are empty. The PR description states that unit tests have been added, but the implementations are missing. Given the complexity of KV caching logic, comprehensive unit tests are essential to ensure correctness. Please implement these tests, covering initialization, prefill, update, and causal mask computation.
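As a starting point, here is a sketch of the kind of check `test_update` could make. `_ToyCache` is a hypothetical stand-in that mirrors the update semantics described in this review (the real `LayerCache` constructor and attribute names may differ), and it uses `unittest` for self-containedness where the repo uses `absltest`:

```python
import unittest

import jax
import jax.numpy as jnp

class _ToyCache:
    """Hypothetical stand-in mirroring the reviewed update semantics."""
    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.k_cache = jnp.zeros((1, cache_size, 1, 1))
        self.v_cache = jnp.zeros((1, cache_size, 1, 1))
        self.cur_ind = 0

    def update(self, k, v):
        idx = (0, self.cur_ind % self.cache_size, 0, 0)
        self.k_cache = jax.lax.dynamic_update_slice(self.k_cache, k, idx)
        self.v_cache = jax.lax.dynamic_update_slice(self.v_cache, v, idx)
        self.cur_ind += k.shape[1]

class TestToyCache(unittest.TestCase):
    def test_update_writes_at_current_index(self):
        cache = _ToyCache(cache_size=4)
        kv = jnp.ones((1, 2, 1, 1))
        cache.update(kv, kv)
        self.assertEqual(cache.cur_ind, 2)
        # The first two slots hold the new tokens; the rest are untouched.
        self.assertTrue(bool((cache.k_cache[0, :2] == 1.0).all()))
        self.assertTrue(bool((cache.k_cache[0, 2:] == 0.0).all()))
```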
```python
class TestCyclicCache(absltest.TestCase):
    def setUp(self):
        super().setUp()

    def test_init(self):
        pass

    def test_prefill(self):
        pass

    def test_update(self):
        pass

    def test_compute_causal_mask(self):
        pass

    def test_update_after_cache_full(self):
        pass
```
The test cases for `TestCyclicCache` are empty. The caching logic here is particularly complex, with wrap-around behavior and tricky causal masking, and the implementation file itself contains TODOs about verifying the logic. It is critical to add comprehensive unit tests covering prefill, single-token updates, multi-token updates (especially edge cases that cause buffer wrapping), and causal mask computation to ensure the implementation is correct.
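For `test_update_after_cache_full` in particular, a test can pin down the expected overwrite order. The snippet below is a toy illustration of the invariant (once full, each new token overwrites the oldest slot, i.e. `cur_ind % cache_size`), not a test of the PR's actual class:

```python
import jax.numpy as jnp

def test_update_after_cache_full():
    # Six single-token updates into a size-4 cyclic buffer: tokens 5 and 6
    # should overwrite slots 0 and 1 (the oldest entries).
    cache_size = 4
    cache = jnp.zeros((cache_size,))
    for cur_ind in range(6):
        cache = cache.at[cur_ind % cache_size].set(float(cur_ind + 1))
    assert cache.tolist() == [5.0, 6.0, 3.0, 4.0]

test_update_after_cache_full()
```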
```python
        self.start_ind_initialized = True

    def prefill(self, k_new: Array, v_new: Array, segment_ids: Array):
        return self.update(k_new, v_new)
```
The `prefill` method in `LayerCache` accepts `segment_ids` but doesn't use them. It directly calls `update`, which relies on `self.start_ind` being initialized. If the cache is initialized without `segment_ids`, a call to `prefill` will fail the assertion inside `update`.

The `prefill` method should be responsible for initializing `start_ind` if it hasn't been done already.
Suggested change:

```python
    def prefill(self, k_new: Array, v_new: Array, segment_ids: Array):
        if not self.start_ind_initialized:
            self._init_start_ind(segment_ids)
        return self.update(k_new, v_new)
```
```python
        assert k_new.shape[1] <= self.cache_size, "Number of input tokens exceeds cache size"
        return self.update(k_new, v_new)
```
Similar to `LayerCache`, the `prefill` method in `CyclicCache` accepts `segment_ids` but doesn't use them, instead calling `update` directly. This can lead to an assertion failure if `start_ind` wasn't initialized at creation. `prefill` should handle the initialization.
Suggested change:

```python
        assert k_new.shape[1] <= self.cache_size, "Number of input tokens exceeds cache size"
        if not self.start_ind_initialized:
            self._init_start_ind(segment_ids)
        return self.update(k_new, v_new)
```
```python
logging.basicConfig(level=logging.INFO)
logging.info("KV cache utilities are still in development.")
logging.info("Need further testing for these functions.")
```
Calling `logging.basicConfig()` within a library is considered bad practice, as it can override the logging configuration of the application using the library. Module-level logging on import is also unconventional. The information about the development status is already present in the docstring. It's recommended to remove these lines.
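If some logging is still desired, the conventional library pattern is a module-level logger with no global configuration (function name here is illustrative):

```python
import logging

# Library code should create its own logger and never call basicConfig();
# the application that imports the library decides handlers and levels.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())  # avoid "no handler" warnings by default

def some_cache_op():
    # Emitted only if the importing application configures logging.
    logger.debug("KV cache op called")
```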
```python
class CacheProtocol(nnx.Module):
    """Protocol for KV cache."""

    def prefill(self, k_new: Array, v_new: Array): ...
```
The signature for `prefill` in `CacheProtocol` is `prefill(self, k_new: Array, v_new: Array)`, but the implementations in `LayerCache` and `CyclicCache` are `prefill(self, k_new: Array, v_new: Array, segment_ids: Array)`. This violates the Liskov Substitution Principle and will cause issues with static analysis and type checking.

Please update the protocol to match the implementations, since `segment_ids` are necessary for pre-filling the cache correctly.
Suggested change:

```python
    def prefill(self, k_new: Array, v_new: Array, segment_ids: Array): ...
```
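For illustration, a minimal `typing.Protocol` analogue shows why the signatures must agree for code written against the interface to work. All names and types below are toy stand-ins, not the PR's `Array`/`nnx.Module` types:

```python
from typing import Protocol

class CacheLike(Protocol):
    def prefill(self, k_new: list, v_new: list, segment_ids: list) -> None: ...

class GoodCache:
    def prefill(self, k_new: list, v_new: list, segment_ids: list) -> None:
        pass

class BadCache:
    # Missing segment_ids: a static type checker rejects this as CacheLike,
    # and calling it through the interface fails at runtime.
    def prefill(self, k_new: list, v_new: list) -> None:
        pass

def run_prefill(cache: CacheLike) -> None:
    cache.prefill([], [], [])

run_prefill(GoodCache())   # fine
# run_prefill(BadCache())  # flagged by a type checker; raises TypeError at runtime
```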
Implements some caching utilities. This is related to #124 and should make kv-caching more consistent across the repo.
Checklist
(`run_model.py` for model usage, `test_outputs.py` and/or `model_validation_colab.ipynb` for quality).