115 changes: 110 additions & 5 deletions bonsai/models/umt5/modeling.py
@@ -85,6 +85,20 @@ def __post_init__(self):
)


class LayerCache(nnx.Module):
"""Cache for storing key and value projections for a single layer."""

def __init__(self, config: UMT5Config, batch_size: int, cache_size: int, dtype: jnp.dtype):
cache_shape = (batch_size, cache_size, config.num_heads, config.d_kv)
self.k_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype))
self.v_cache = nnx.Cache(jnp.zeros(cache_shape, dtype=dtype))
self.size = cache_size
self.cur_ind = nnx.Variable(jnp.zeros((), dtype=jnp.int32))


Cache: TypeAlias = list[LayerCache | None]
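
For orientation, a minimal standalone sketch (not part of the diff) of the update pattern a LayerCache supports: a fixed-size buffer written in place with jax.lax.dynamic_update_slice while cur_ind tracks how many positions have been filled. The shapes below are illustrative assumptions.

import jax.numpy as jnp
from jax import lax

batch, cache_size, num_heads, d_kv = 2, 16, 8, 64
k_cache = jnp.zeros((batch, cache_size, num_heads, d_kv))
cur_ind = 0

# One decode step writes q_len = 1 new key vectors at position cur_ind.
new_k = jnp.ones((batch, 1, num_heads, d_kv))
k_cache = lax.dynamic_update_slice(k_cache, new_k, (0, cur_ind, 0, 0))
cur_ind += 1  # the nnx module keeps this counter in an nnx.Variable instead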


class T5LayerNorm(nnx.Module):
def __init__(
self,
@@ -286,6 +300,7 @@ def __call__(
hidden_states: jax.Array,
encoder_hidden_states: jax.Array | None = None,
attention_mask: jax.Array | None = None,
cache: LayerCache | None = None,
):
b, n, c = hidden_states.shape[0], self.n_heads, self.key_value_proj_dim

@@ -297,6 +312,25 @@ def __call__(
k = self.k(current_states).reshape(b, -1, n, c)
v = self.v(current_states).reshape(b, -1, n, c)

# Apply KV-cache if provided
if cache is not None:
if is_cross_attention:
# For cross-attention, cache encoder K/V once and reuse
if cache.cur_ind.value == 0:
cache.k_cache[...] = k
cache.v_cache[...] = v
k = cache.k_cache[...]
v = cache.v_cache[...]
else:
# For self-attention, build cache incrementally
q_len = q.shape[1]
slice_indices = (0, cache.cur_ind.value, 0, 0)
cache.k_cache[...] = jax.lax.dynamic_update_slice(cache.k_cache[...], k, slice_indices)
cache.v_cache[...] = jax.lax.dynamic_update_slice(cache.v_cache[...], v, slice_indices)
k = cache.k_cache[...]
v = cache.v_cache[...]
cache.cur_ind.value = cache.cur_ind.value + q_len
Contributor comment on lines +330 to +332 (critical):
The current implementation retrieves the full cache tensors for keys and values but doesn't mask the unused padded portions. This can lead to incorrect attention scores, because the softmax is computed over the zero-valued padded positions, and it is inefficient, because computation is performed over padding. To fix this, slice the cache to its valid length after updating it.

Suggested change:
- k = cache.k_cache[...]
- v = cache.v_cache[...]
- cache.cur_ind.value = cache.cur_ind.value + q_len
+ new_len = cache.cur_ind.value + q_len
+ k = jax.lax.dynamic_slice(cache.k_cache[...], (0, 0, 0, 0), (k.shape[0], new_len, k.shape[2], k.shape[3]))
+ v = jax.lax.dynamic_slice(cache.v_cache[...], (0, 0, 0, 0), (v.shape[0], new_len, v.shape[2], v.shape[3]))
+ cache.cur_ind.value = new_len
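
One caveat with the suggestion above: under jax.jit, jax.lax.dynamic_slice requires static slice sizes, so slicing to a traced new_len may not compile. A hedged alternative sketch (not part of the PR or of the suggestion) keeps the cache at its full static length and instead masks the not-yet-written key positions out of the attention bias; position_bias is an assumed name for whatever bias the surrounding code adds before the softmax.

# Hedged sketch; `cache`, `k`, `v`, `q_len` match the surrounding code,
# `position_bias` is an assumed name for the pre-softmax attention bias.
new_len = cache.cur_ind.value + q_len
k = cache.k_cache[...]  # full (b, cache_size, n, c); shape stays static
v = cache.v_cache[...]
valid = jnp.arange(cache.size) < new_len  # (cache_size,) mask of filled slots
pad_bias = jnp.where(valid, 0.0, jnp.finfo(k.dtype).min)  # large negative on empty slots
position_bias = position_bias + pad_bias[None, None, None, :]  # broadcasts over (b, heads, q_len, k_len)
cache.cur_ind.value = new_len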


# Attention bias
q_len, k_len = q.shape[1], k.shape[1]
if not self.has_relative_attention_bias:
@@ -358,11 +392,13 @@ def __call__(
self,
hidden_states: jax.Array,
attention_mask=None,
cache: LayerCache | None = None,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
attention_mask=attention_mask,
cache=cache,
)
outputs = hidden_states + attention_output
return outputs
@@ -389,12 +425,14 @@ def __call__(
hidden_states: jax.Array,
encoder_hidden_states: jax.Array = None,
attention_mask: jax.Array = None,
cache: LayerCache | None = None,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cache=cache,
)
return hidden_states + self.dropout(attention_output)

@@ -423,12 +461,15 @@ def __call__(
attention_mask: jax.Array = None,
encoder_hidden_states: jax.Array = None,
encoder_attention_mask: jax.Array = None,
self_attn_cache: LayerCache | None = None,
cross_attn_cache: LayerCache | None = None,
):
# Apply self-attention layer
hidden_states = fp16_clamp(
self.layer[0](
hidden_states,
attention_mask=attention_mask,
cache=self_attn_cache,
)
)
# Cross-Attention Block
@@ -439,6 +480,7 @@ def __call__(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
cache=cross_attn_cache,
)
)

@@ -543,6 +585,8 @@ def __call__(
attention_mask: jax.Array = None,
encoder_hidden_states: jax.Array = None,
encoder_attention_mask: jax.Array = None,
self_attn_cache: Cache | None = None,
cross_attn_cache: Cache | None = None,
):
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.dropout(inputs_embeds)
@@ -552,12 +596,16 @@ def __call__(
input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds.dtype
)

for _, layer_module in enumerate(self.block):
for i, layer_module in enumerate(self.block):
layer_self_cache = self_attn_cache[i] if self_attn_cache is not None else None
layer_cross_cache = cross_attn_cache[i] if cross_attn_cache is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask=causal_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
self_attn_cache=layer_self_cache,
cross_attn_cache=layer_cross_cache,
)
hidden_states = layer_outputs[0]
hidden_states = self.final_layer_norm(hidden_states)
@@ -629,6 +677,42 @@ def __init__(
rngs=rngs,
)

def init_cache(
self,
batch_size: int,
max_decoder_length: int,
encoder_sequence_length: int,
dtype: jnp.dtype = jnp.float32,
) -> tuple[Cache, Cache]:
"""Initialize KV-caches for decoder self-attention and cross-attention.

Args:
batch_size: Batch size for the cache.
max_decoder_length: Maximum decoder sequence length.
encoder_sequence_length: Encoder sequence length (for cross-attention cache).
dtype: Data type for cache arrays.

Returns:
Tuple of (self_attn_cache, cross_attn_cache), each a list of LayerCache per decoder layer.
"""
decoder_config = copy.deepcopy(self.config)
decoder_config.is_decoder = True
decoder_config.num_layers = self.config.num_decoder_layers

# Self-attention cache grows with generation
self_attn_cache = [
LayerCache(decoder_config, batch_size, max_decoder_length, dtype)
for _ in range(decoder_config.num_layers)
]

# Cross-attention cache is fixed (encoder length)
cross_attn_cache = [
LayerCache(decoder_config, batch_size, encoder_sequence_length, dtype)
for _ in range(decoder_config.num_layers)
]

return self_attn_cache, cross_attn_cache
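
A minimal usage sketch of init_cache (not from the PR), assuming model is an instance of the class this diff modifies and the sizes are illustrative:

batch_size, max_decoder_length, encoder_len = 2, 64, 128
self_attn_cache, cross_attn_cache = model.init_cache(
    batch_size=batch_size,
    max_decoder_length=max_decoder_length,
    encoder_sequence_length=encoder_len,
    dtype=jnp.float32,
)
# One LayerCache per decoder layer: self-attention caches hold up to
# max_decoder_length positions, cross-attention caches hold encoder_len.
assert len(self_attn_cache) == model.config.num_decoder_layers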

def __call__(
self,
input_ids: jax.Array | None = None,
@@ -656,22 +740,23 @@ def __call__(

return decoder_outputs

# TODO(#96): Implement KV Cache for efficient inference
def generate(
self,
input_ids: jax.Array,
attention_mask: jax.Array = None,
max_tokens: int | None = None,
max_new_tokens: int | None = None,
use_cache: bool = True,
) -> jax.Array:
"""Generate sequences using greedy decoding.
"""Generate sequences using greedy decoding with optional KV-caching.

Args:
input_ids: Encoder input ids from tokenizer, shape (batch_size, seq_length)
attention_mask: Encoder attention mask, shape (batch_size, seq_length)
max_tokens: Maximum total length of decoder sequence (including start token).
Takes precedence over max_new_tokens if both are provided.
max_new_tokens: Maximum number of new tokens to generate (excluding start token)
use_cache: Whether to use KV-caching for efficiency (default: True)

Returns:
Generated token ids, shape (batch_size, generated_length)
@@ -694,17 +779,37 @@ def generate(
batch_size = input_ids.shape[0]
decoder_input_ids = jnp.full((batch_size, 1), self.config.decoder_start_token_id, dtype=jnp.int32)

# Initialize caches if enabled
if use_cache:
encoder_seq_len = encoder_outputs.shape[1]
self_attn_cache, cross_attn_cache = self.init_cache(
batch_size=batch_size,
max_decoder_length=max_length,
encoder_sequence_length=encoder_seq_len,
dtype=encoder_outputs.dtype,
)
else:
self_attn_cache = None
cross_attn_cache = None

# Autoregressive generation loop
for _ in range(max_length - 1):
# Decoder forward pass
if use_cache:
# Only pass the last token when using cache
current_input = decoder_input_ids[:, -1:]
else:
current_input = decoder_input_ids

decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
input_ids=current_input,
encoder_hidden_states=encoder_outputs,
self_attn_cache=self_attn_cache,
cross_attn_cache=cross_attn_cache,
)

# Get logits and select next token (greedy)
logits = self.lm_head(decoder_outputs)
# Here we use simple greedy decoding; beam search is recommended for better quality
next_token = jnp.argmax(logits[:, -1, :], axis=-1, keepdims=True)

# Append to decoder input
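
End to end, the cached generation path added in this diff would be driven roughly as follows. This is a hedged usage sketch: model stands for an instance of the model class modified here, and the tokenized inputs and shapes are assumptions rather than taken from the PR.

import jax.numpy as jnp

input_ids = jnp.ones((2, 16), dtype=jnp.int32)       # placeholder tokenized batch
attention_mask = jnp.ones((2, 16), dtype=jnp.int32)  # placeholder encoder mask

generated = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=32,
    use_cache=True,  # reuse decoder K/V instead of recomputing them each step
)
print(generated.shape)  # (batch_size, generated_length)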