cache utilities #155

Draft

chapman20j wants to merge 1 commit into jax-ml:main from chapman20j:cache

Conversation

@chapman20j (Collaborator)

Implements some caching utilities. This is related to #124 and should make KV caching more consistent across the repo.

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@gemini-code-assist (Contributor)

Summary of Changes

Hello @chapman20j, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces foundational Key-Value (KV) caching utilities for transformer models within the bonsai repository. It provides two distinct caching mechanisms, LayerCache and CyclicCache, along with a common protocol to ensure consistent behavior. The goal is to streamline KV-caching operations and address related issues, setting up a more robust framework for managing model states efficiently.

Highlights

  • New KV Cache Utilities: Introduced a new module bonsai/utils/kv_cache.py which provides common Key-Value (KV) cache strategies for transformer models, aiming for more consistent caching across the repository.
  • LayerCache Implementation: Implemented LayerCache, a pre-allocated cache that updates KV states sequentially, designed for single-layer caching.
  • CyclicCache Implementation: Implemented CyclicCache, another pre-allocated cache that overwrites the oldest entries once its capacity is reached, also for single-layer caching.
  • Cache Protocol and Utilities: Defined a CacheProtocol for a standardized interface across different cache types and included a compute_left_pads utility function.
  • Placeholder Tests Added: A new test file bonsai/utils/tests/kv_cache_test.py was added, containing placeholder test classes and methods for both LayerCache and CyclicCache to ensure future test coverage.
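The two strategies differ mainly in how a monotonically growing write position maps to a buffer slot. A framework-free sketch of the cyclic mapping (illustrative only; the PR implements this over JAX arrays as `nnx.Module` state, and the class name here is hypothetical):

```python
class CyclicIndexer:
    """Maps an ever-increasing write position onto a fixed-size ring buffer,
    mirroring CyclicCache's `cur_ind % cache_size` slot arithmetic."""

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.cur_ind = 0                 # total tokens ever written
        self.buf = [None] * cache_size   # stands in for the KV arrays

    def update(self, tokens):
        # Each token lands at the wrapped position; once the buffer is
        # full, the oldest entry is silently overwritten.
        for t in tokens:
            self.buf[self.cur_ind % self.cache_size] = t
            self.cur_ind += 1

ring = CyclicIndexer(cache_size=4)
ring.update(["a", "b", "c", "d", "e"])  # fifth token wraps to slot 0
print(ring.buf)  # ['e', 'b', 'c', 'd']
```

By contrast, a `LayerCache`-style sequential cache would simply refuse (or fail) once `cur_ind` reaches capacity instead of wrapping.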


Changelog
  • bonsai/utils/kv_cache.py
    • Added a new file implementing KV cache strategies for transformer models.
    • Defined CacheProtocol as an nnx.Module for cache interface standardization.
    • Implemented LayerCache for sequential KV state storage and updates.
    • Implemented CyclicCache for cyclic KV state storage, overwriting oldest entries.
    • Included compute_left_pads function to determine padding from segment IDs.
    • Added basic example usage for both LayerCache and CyclicCache within the if __name__ == '__main__': block.
  • bonsai/utils/tests/kv_cache_test.py
    • Added a new test file for KV cache utilities.
    • Created TestLayerCache class with placeholder tests for init, prefill, update, and compute_causal_mask.
    • Created TestCyclicCache class with placeholder tests for init, prefill, update, compute_causal_mask, and test_update_after_cache_full.
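Assuming `compute_left_pads` counts leading padding positions (segment id 0) per sequence — the excerpt does not show its body, so this is a guess at the semantics — a plain-Python reference of that behavior might look like:

```python
def compute_left_pads(segment_ids):
    """Count leading pad positions (segment id 0) in each sequence.

    Hypothetical reference semantics; the PR's JAX version would
    presumably vectorize this with a cumulative comparison."""
    pads = []
    for row in segment_ids:
        n = 0
        for s in row:
            if s != 0:
                break  # first real token ends the left-pad run
            n += 1
        pads.append(n)
    return pads

print(compute_left_pads([[0, 0, 1, 1], [1, 1, 1, 1]]))  # [2, 0]
```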
Activity
  • The author chapman20j opened this pull request.
  • New files bonsai/utils/kv_cache.py and bonsai/utils/tests/kv_cache_test.py were added.
  • The author has completed the contribution guidelines checklist, including adding unit tests (albeit as placeholders for now).

@gemini-code-assist bot left a comment

Code Review

This pull request introduces KV caching utilities, including a standard LayerCache and a CyclicCache, along with a protocol they adhere to. While the initiative to create consistent caching utilities is great, this PR has some critical issues that need to be addressed before merging.

Most importantly, the associated test file bonsai/utils/tests/kv_cache_test.py contains only empty test stubs, despite the PR checklist indicating that tests have been added. The caching logic, especially for CyclicCache, is complex and includes author notes about needing verification. It is crucial to add comprehensive unit tests.

Additionally, there are several issues in the implementation:

  • The CyclicCache.update method has a bug that can lead to out-of-bounds writes when an update wraps around the buffer.
  • The prefill methods in both cache classes don't correctly handle segment_ids for initialization, which will lead to runtime errors.
  • The CacheProtocol has a method signature mismatch with its implementations.
  • The module configures logging at import time, which is a bad practice for a library.

Please address these points, with the highest priority on adding tests.

Comment on lines +204 to +209
```python
    def update(self, k: Array, v: Array):
        assert self.start_ind_initialized, "Must initialize start_ind before updating LayerCache"
        slice_indices = (0, self.cur_ind[...] % self.cache_size, 0, 0)
        self.k_cache[...] = jax.lax.dynamic_update_slice(self.k_cache[...], k, slice_indices)
        self.v_cache[...] = jax.lax.dynamic_update_slice(self.v_cache[...], v, slice_indices)
        self.cur_ind[...] = self.cur_ind[...] + k.shape[1]
```

critical

The update method in CyclicCache mis-writes when an update wraps around the cache buffer. When k.shape[1] > 1 and the slice would extend past the end of the cache, jax.lax.dynamic_update_slice clamps the start index so the whole update fits, silently overwriting the wrong slots rather than raising an error. For example, if cache_size=4, cur_ind=3, and k.shape[1]=2, the write should start at index 3 and wrap to index 0, but the start is clamped to 2 and indices 2-3 are overwritten instead.

While the TODO notes an assumption of single-token updates post-prefill, the code doesn't enforce it. To prevent incorrect behavior, you should either handle the wrap-around case (which would require splitting the update into two dynamic_update_slice calls) or add an assertion to enforce the assumption that updates do not wrap.
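For illustration, the split-write fix can be sketched on a plain Python list; the real implementation would perform the same two contiguous writes with `jax.lax.dynamic_update_slice` on the cache arrays (the function name here is illustrative, not from the PR):

```python
def write_with_wrap(buf, start, chunk):
    """Write `chunk` into ring buffer `buf` starting at slot `start`,
    splitting into two contiguous writes when the chunk wraps around.

    Simplified list-based sketch of the two-write fix."""
    size = len(buf)
    assert len(chunk) <= size, "chunk larger than buffer"
    first = min(len(chunk), size - start)
    buf[start:start + first] = chunk[:first]   # tail segment up to the end
    buf[0:len(chunk) - first] = chunk[first:]  # wrapped head segment at slot 0
    return buf

# cache_size=4, cur_ind=3, writing 2 tokens: slot 3, then wrap to slot 0.
print(write_with_wrap(["k0", "k1", "k2", "k3"], 3, ["new0", "new1"]))
# ['new1', 'k1', 'k2', 'new0']
```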

Suggested change:

```python
    def update(self, k: Array, v: Array):
        assert self.start_ind_initialized, "Must initialize start_ind before updating LayerCache"
        assert self.cur_ind[...] % self.cache_size + k.shape[1] <= self.cache_size, "Update logic does not support wrapping around the buffer yet."
        slice_indices = (0, self.cur_ind[...] % self.cache_size, 0, 0)
        self.k_cache[...] = jax.lax.dynamic_update_slice(self.k_cache[...], k, slice_indices)
        self.v_cache[...] = jax.lax.dynamic_update_slice(self.v_cache[...], v, slice_indices)
        self.cur_ind[...] = self.cur_ind[...] + k.shape[1]
```

Comment on lines +19 to +33
```python
class TestLayerCache(absltest.TestCase):
    def setUp(self):
        super().setUp()

    def test_init(self):
        pass

    def test_prefill(self):
        pass

    def test_update(self):
        pass

    def test_compute_causal_mask(self):
        pass
```

critical

The test cases for TestLayerCache are empty. The PR description states that unit tests have been added, but the implementations are missing. Given the complexity of KV caching logic, comprehensive unit tests are essential to ensure correctness. Please implement these tests, covering initialization, prefill, update, and causal mask computation.

Comment on lines +36 to +53
```python
class TestCyclicCache(absltest.TestCase):
    def setUp(self):
        super().setUp()

    def test_init(self):
        pass

    def test_prefill(self):
        pass

    def test_update(self):
        pass

    def test_compute_causal_mask(self):
        pass

    def test_update_after_cache_full(self):
        pass
```

critical

The test cases for TestCyclicCache are empty. The caching logic here is particularly complex, with wrap-around behavior and tricky causal masking, and the implementation file itself contains TODOs about verifying the logic. It is critical to add comprehensive unit tests covering prefill, single-token updates, multi-token updates (especially edge cases that cause buffer wrapping), and causal mask computation to ensure the implementation is correct.
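As a sketch of what the wrap-around case could assert, here is a self-contained `unittest` example driven against a throwaway list-backed stand-in; a real test would exercise `CyclicCache` itself with small JAX arrays (all names below are illustrative, not from the PR):

```python
import unittest

class RingReference:
    """Minimal list-backed ring buffer used as a stand-in oracle."""

    def __init__(self, size):
        self.size = size
        self.buf = [None] * size
        self.n = 0  # total tokens written so far

    def update(self, tokens):
        for t in tokens:
            self.buf[self.n % self.size] = t
            self.n += 1

class TestUpdateAfterCacheFull(unittest.TestCase):
    def test_oldest_entry_overwritten(self):
        ring = RingReference(4)
        ring.update([0, 1, 2, 3])  # prefill fills the cache exactly
        ring.update([4])           # next decode step must wrap to slot 0
        self.assertEqual(ring.buf, [4, 1, 2, 3])
        self.assertEqual(ring.n, 5)

result = unittest.TestResult()
TestUpdateAfterCacheFull("test_oldest_entry_overwritten").run(result)
print(result.wasSuccessful())  # True
```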

```python
        self.start_ind_initialized = True

    def prefill(self, k_new: Array, v_new: Array, segment_ids: Array):
        return self.update(k_new, v_new)
```

high

The prefill method in LayerCache accepts segment_ids but doesn't use them. It directly calls update, which relies on self.start_ind being initialized. If the cache is initialized without segment_ids, a call to prefill will fail the assertion inside update.

The prefill method should be responsible for initializing start_ind if it hasn't been done already.

Suggested change:

```python
        if not self.start_ind_initialized:
            self._init_start_ind(segment_ids)
        return self.update(k_new, v_new)
```

Comment on lines +201 to +202
```python
        assert k_new.shape[1] <= self.cache_size, "Number of input tokens exceeds cache size"
        return self.update(k_new, v_new)
```

high

Similar to LayerCache, the prefill method in CyclicCache accepts segment_ids but doesn't use them, instead calling update directly. This can lead to an assertion failure if start_ind wasn't initialized at creation. prefill should handle the initialization.

Suggested change:

```python
        assert k_new.shape[1] <= self.cache_size, "Number of input tokens exceeds cache size"
        if not self.start_ind_initialized:
            self._init_start_ind(segment_ids)
        return self.update(k_new, v_new)
```

Comment on lines +53 to +55
```python
logging.basicConfig(level=logging.INFO)
logging.info("KV cache utilities are still in development.")
logging.info("Need further testing for these functions.")
```

medium

Calling logging.basicConfig() within a library is considered bad practice as it can override the logging configuration of the application using the library. Module-level logging on import is also unconventional. The information about the development status is already present in the docstring. It's recommended to remove these lines.
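For reference, the conventional library-side pattern is a named module-level logger with no configuration, leaving handlers, levels, and formatting to the application (the module path below matches the PR's file, but the snippet itself is an illustrative sketch):

```python
import logging

# Library code: create a named logger, never call logging.basicConfig().
# The application owns handler, level, and format configuration.
logger = logging.getLogger("bonsai.utils.kv_cache")

def prefill_stub():
    # Emitted only if the *application* opts into DEBUG for this logger.
    logger.debug("prefill called")
```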

```python
class CacheProtocol(nnx.Module):
    """Protocol for KV cache."""

    def prefill(self, k_new: Array, v_new: Array): ...
```

medium

The signature for prefill in CacheProtocol is prefill(self, k_new: Array, v_new: Array), but the implementations in LayerCache and CyclicCache are prefill(self, k_new: Array, v_new: Array, segment_ids: Array). This violates the Liskov Substitution Principle and will cause issues with static analysis and type checking.

Please update the protocol to match the implementation, as segment_ids are necessary for pre-filling the cache correctly.

Suggested change:

```python
    def prefill(self, k_new: Array, v_new: Array, segment_ids: Array): ...
```
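This kind of signature drift can also be surfaced mechanically. A small stdlib-only sketch (toy stand-in classes, not the PR's nnx-based ones) that detects the mismatch with `inspect`:

```python
import inspect

# Toy stand-ins reproducing the base/override signature drift.
class CacheProtocol:
    def prefill(self, k_new, v_new): ...

class LayerCache(CacheProtocol):
    def prefill(self, k_new, v_new, segment_ids):
        return None

base = inspect.signature(CacheProtocol.prefill)
impl = inspect.signature(LayerCache.prefill)
print(base == impl)                      # False: the override adds segment_ids
print("segment_ids" in impl.parameters)  # True
```

A static checker such as mypy or pyright would flag the same incompatibility at analysis time.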
