feat(llama32): enhance LlamaShardCfg with activation and logits partition specs
Code Review
This pull request adds a comprehensive implementation of the Llama 3.2 model, including model definition, weight loading from Hugging Face, example run scripts, and a suite of tests for correctness, padding, and sharding. The code is well-structured and follows good practices for writing JAX/Flax models.
I've identified a couple of minor issues: one regarding some dead code in the parameter loading script and another concerning a loose tolerance in one of the output parity tests. Addressing these will improve the code's clarity and robustness. Overall, this is a great contribution.
    if cfg.tie_word_embeddings and "lm_head" in state_dict:
        state_dict["lm_head"]["kernel"] = state_dict["embedder"]["embedding"].T
This block of code appears to be unreachable and can be removed for clarity.

- If `cfg.tie_word_embeddings` is `True`, `self.lm_head` is `None` in the `Llama` model, so `"lm_head"` won't be in `state_dict`.
- If `cfg.tie_word_embeddings` is `False`, the condition `cfg.tie_word_embeddings` is false.

In either scenario, this code block is dead. The model correctly handles tied embeddings through `embedder.decode()`.
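For context, "tied" embeddings mean the output projection reuses the transposed embedding table, so no separate `lm_head` kernel ever exists in the state dict. A minimal NumPy sketch of the `embedder.decode()`-style path (shapes and names are illustrative, not the PR's actual code):

```python
import numpy as np

# Illustrative only: with tied embeddings, logits come from the transposed
# embedding table, so no separate lm_head kernel is stored.
vocab, hidden = 8, 4
embedding = np.random.randn(vocab, hidden)   # the embedder's table
x = np.random.randn(2, hidden)               # final hidden states
logits = x @ embedding.T                     # embedder.decode()-equivalent
print(logits.shape)                          # (2, 8)
```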
    t_logits = self.torch_model(**t_inputs).logits

    n_logits = self.llama_model(n_tokens, segment_ids, cache=None, attn_mask=None)
    np.testing.assert_allclose(n_logits, t_logits.detach().cpu().numpy(), rtol=5e-2, atol=5e-2)
The tolerance for this `assert_allclose` is very high (5e-2), which might hide subtle correctness issues between your implementation and the reference PyTorch model. While floating-point differences are expected across frameworks, such a large tolerance is concerning.
Please investigate whether a tighter tolerance can be achieved. For reference, `test_forward_logits` uses a much tighter tolerance of 1e-4. It would be ideal to have similar precision for the full logit comparison to ensure correctness.
    - np.testing.assert_allclose(n_logits, t_logits.detach().cpu().numpy(), rtol=5e-2, atol=5e-2)
    + np.testing.assert_allclose(n_logits, t_logits.detach().cpu().numpy(), rtol=1e-4, atol=1e-4)
Sharding config is implemented, but testing on multi-device clusters is pending.

If you want to debug/test small models locally, you can just set up the mesh to be
    cls.batch_size = 4
    cls.num_input_tokens = 6
    cls.relaxed_tol = 1e-3
Can we use a different tolerance for each test? 1e-3 is relatively large for individual layers and may mask real numerical differences.
+1. Since we are computing this in the highest precision, we want tighter constraints (e.g., 1e-5). Also, let's not have `relaxed_tol` as a shared class variable; maybe make it explicit in each test, as tests may require different constraints?
e.g.:

    assert_close(x, y, rtol=1e-5, atol=1e-3)
I’ve set the tolerance to 1e-5 for most tests.
For the specific functions below, I measured the actual maximum error and set the tolerance to roughly twice that value as a safety margin.
[ RUN ] TestOutputsLlama32.test_forward_logits
test_forward_logits max_abs 4.5776367e-05 max_rel 0.29375324
[ OK ] TestOutputsLlama32.test_forward_logits
[ RUN ] TestOutputsLlama32.test_full_logits
test_full_logits max_abs 0.00027751923 max_rel 7.5256386

Test Code for Verification:
    def test_full_logits(self):
        t_inputs = self._make_torch_input()
        n_tokens = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
        attention_mask = jnp.array(t_inputs["attention_mask"].detach().cpu().numpy())
        segment_ids = attention_mask.astype(jnp.int32)
        with torch.no_grad():
            t_logits = self.torch_model(**t_inputs).logits
        n_logits = self.llama_model(n_tokens, segment_ids, cache=None, attn_mask=None)
        t_logits_np = t_logits.detach().cpu().numpy()
        n_logits_np = np.array(n_logits)
        diff = np.abs(n_logits_np - t_logits_np)
        rel = diff / np.maximum(np.abs(t_logits_np), 1e-8)
        print("test_full_logits max_abs", diff.max(), "max_rel", rel.max())
        np.testing.assert_allclose(n_logits_np, t_logits_np, rtol=1e-5, atol=1e-5)

    def test_forward_logits(self):
        t_inputs = self._make_torch_input()
        n_tokens = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
        attention_mask = jnp.array(t_inputs["attention_mask"].detach().cpu().numpy())
        batch_size, token_len = n_tokens.shape
        cache = self.llama_model.init_cache(self.llama_config, batch_size, token_len, generate_steps=1)
        with torch.no_grad():
            t_logits = self.torch_model(**t_inputs).logits
        n_logits, _ = modeling.forward(self.llama_model, cache, n_tokens, self.pad_id, attention_mask=attention_mask)
        t_logits_np = t_logits[:, -1].detach().cpu().numpy()
        n_logits_np = np.array(n_logits)
        diff = np.abs(n_logits_np - t_logits_np)
        rel = diff / np.maximum(np.abs(t_logits_np), 1e-8)
        print("test_forward_logits max_abs", diff.max(), "max_rel", rel.max())
        np.testing.assert_allclose(n_logits_np, t_logits_np, rtol=1e-4, atol=1e-4)

    torch.tensor(np.array(jy, dtype=np.float32)),
    ty,
    rtol=self.relaxed_tol,
    atol=self.relaxed_tol,
    check_dtype=False,
Could we enable dtype checking?
    atol=self.relaxed_tol,
    check_dtype=False,
    )
Could you add a test here for the attention and/or decoder? This helps identify potential errors in the masking, etc.
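One cheap masking check such a test could include: perturbing a future position must not change earlier outputs. A self-contained NumPy sketch of the idea (toy single-head attention, not the PR's implementation):

```python
import numpy as np

def causal_attention(q, k, v):
    # Toy single-head attention with an explicit causal mask.
    T = q.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    mask = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(mask, scores, -1e30)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 4, 8))
out1 = causal_attention(q, k, v)
k2, v2 = k.copy(), v.copy()
k2[-1] += 1.0   # perturb the last (future) position
v2[-1] += 1.0
out2 = causal_attention(q, k2, v2)
# Causality: earlier positions must be unaffected by the perturbation.
assert np.allclose(out1[:-1], out2[:-1])
```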
    from jax.sharding import AxisType

    from bonsai.models.llama32 import modeling
Can you update the tests in this file to use absl for consistency with other models in the repo?
    def _tiny_sharded_config() -> modeling.ModelConfig:
        return modeling.ModelConfig(
            vocab_size=64,
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            head_dim=8,
            num_key_value_heads=2,
            max_position_embeddings=128,
            rms_norm_eps=1e-5,
            rope_theta=10000.0,
            rope_scaling=None,
            tie_word_embeddings=True,
            shd_cfg=modeling.LlamaShardCfg.default(use_fsdp=True, use_tp=True),
            dtype=jnp.float32,
        )
Could combine this with the `_tiny_config` and accept an argument for whether to use sharding.
Oh. I see that you're doing the first solution in your tests. Are you doing some kind of performance testing?
    def compute_positions_from_segment_ids(seg_ids: Array) -> Array:
        """Compute position ids from segment ids."""
        seg_ids = seg_ids.astype(jnp.int32)
        pad_sentinel = 2**30

        def step(carry: tuple[Array, Array], seg_id: Array) -> tuple[tuple[Array, Array], Array]:
            prev_seg, prev_pos = carry
            is_pad = seg_id == 0
            is_new = seg_id != prev_seg
            zero = jnp.zeros_like(seg_id)
            pos = jnp.where(is_pad, zero, jnp.where(is_new, zero, prev_pos + 1))
            pad_val = jnp.full_like(seg_id, pad_sentinel)
            out = jnp.where(is_pad, pad_val, pos)
            new_prev_seg = jnp.where(is_pad, zero, seg_id)
            new_prev_pos = jnp.where(is_pad, zero, pos)
            return (new_prev_seg, new_prev_pos), out

        base = jnp.zeros_like(seg_ids[:, 0])
        init = (base, base)
        _, out = jax.lax.scan(step, init, seg_ids.T)
        return cast(Array, out.T)
Does the implementation still pass tests when using a simpler `compute_positions_from_segment_ids`? e.g. the code from gemma3 looks like

    def compute_positions_from_segment_ids(seg_ids: Array):
        return jax.vmap(lambda row: jnp.where(row != 0, jnp.arange(seg_ids.shape[1]) - jnp.argmax(row), 2**30))(seg_ids)
Unfortunately, the simpler implementation fails the tests because the current logic supports packed sequences.

For example, with seg_ids = [1, 1, 2, 2]:

- Simpler implementation: [0, 1, 2, 3]
- Current implementation: [0, 1, 0, 1] (matches the test expectation)

I agree that the simpler implementation is cleaner. Do you think we should prioritize the simpler implementation, or maintain the current logic to support packing?
We can go with the current logic then. Thanks for clarifying. Could you mention this in the docstring for the method? e.g. "Compute position ids from segment ids with support for packed sequences."
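For reference, the packed-sequence behavior discussed here can be checked with a plain-NumPy restatement of the scan logic (same pad sentinel as the PR; the loop is illustrative, not the PR's code):

```python
import numpy as np

PAD_SENTINEL = 2**30  # same sentinel the PR uses for pad positions

def compute_positions(seg_ids):
    """Positions reset at each new segment; pads get PAD_SENTINEL."""
    seg_ids = np.asarray(seg_ids, dtype=np.int64)
    out = np.zeros_like(seg_ids)
    for b in range(seg_ids.shape[0]):
        prev_seg, prev_pos = 0, 0
        for t, s in enumerate(seg_ids[b]):
            if s == 0:                       # padding token
                out[b, t] = PAD_SENTINEL
                prev_seg, prev_pos = 0, 0
            else:
                pos = 0 if s != prev_seg else prev_pos + 1
                out[b, t] = pos
                prev_seg, prev_pos = s, pos
    return out

print(compute_positions([[1, 1, 2, 2]]))  # [[0 1 0 1]]
print(compute_positions([[1, 1, 0, 0]]))  # pads map to the sentinel
```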
Hi @Moriyuki-S. Thank you for the nice PR. We left a few comments on the PR. Could you please address these? One other comment is that the folder name should probably be
jenriver left a comment
Thanks for the great PR! This looks solid and should be good to go after just some minor cleanups we pointed out.
Also, please check out contribution guidelines and run the precommit hooks to ensure proper linting and formatting! :)
    @@ -0,0 +1,113 @@
    # Copyright 2026 The JAX Authors.
This looks mostly identical to run_model.py, could you merge?
Sure! I originally planned to implement the base and instruct models in different files, but I agree that a single file is simpler.
I'll merge them and add argument parsing to handle the switch. What do you think about running it like this?
# Run Instruct model (Default: 1B-Instruct)
python3 run_model.py
# Run 3B Base model
python3 run_model.py --size 3B --base
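A minimal sketch of that interface with argparse (flag names taken from the example above; choices and defaults are assumptions):

```python
import argparse

def parse_args(argv=None):
    # Hypothetical CLI for the merged run_model.py.
    p = argparse.ArgumentParser(description="Run a Llama 3.2 model")
    p.add_argument("--size", choices=["1B", "3B"], default="1B")
    p.add_argument("--base", action="store_true",
                   help="run the base model instead of the instruct model")
    return p.parse_args(argv)

args = parse_args(["--size", "3B", "--base"])
print(args.size, args.base)  # 3B True
```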
Thanks for the tip! I didn't know about that config setting. I'll try debugging with it.
Could you let me know what specific metrics you are looking for (e.g., throughput, step time)?
@chapman20j @jenriver
Hi @Moriyuki-S. Thanks. Let us know when you want us to take another look at the code. We plan to do more multi-hardware testing in the future. If you have multiple devices and want to run tests we'd appreciate having some benchmarking documented in the readme. You can see some of this in the earlier commits to this repo (https://github.com/jax-ml/bonsai/tree/387f7005b4da2c0b1ed272b59b053a2e4dfef8f6/bonsai/models/qwen3). This shows supported hardware setups and we had profiling in the
@chapman20j @jenriver
@chapman20j
    emb_weight: PartitionSpec | None = None
    activation: PartitionSpec | None = None
    logits: PartitionSpec | None = None

    # Attention
    q_proj: PartitionSpec | None = None
    k_proj: PartitionSpec | None = None
    v_proj: PartitionSpec | None = None
    o_proj: PartitionSpec | None = None

    attn_logits: PartitionSpec | None = None
    attn_out: PartitionSpec | None = None

    cache: PartitionSpec | None = None

    # MLP
    gate_proj: PartitionSpec | None = None
    up_proj: PartitionSpec | None = None
    down_proj: PartitionSpec | None = None

    # Head
    lm_head: PartitionSpec | None = None
I noticed this was updated in Gemma 3, so I aligned this implementation with it. (#141)
    cache_shd = self.config.shd_cfg.cache
    k = shard(k, cache_shd)
    v = shard(v, cache_shd)
    cache.k_cache[...] = jax.lax.dynamic_update_slice(cache.k_cache[...], k, slice_indices)
    cache.v_cache[...] = jax.lax.dynamic_update_slice(cache.v_cache[...], v, slice_indices)
I fixed the sharding error by explicitly resharding k/v to match the cache layout.

Error Details:

    ...
    Traceback (most recent call last):
      File ".../models/llama3_2/tests/test_sharding_llama3_2.py", line 46, in test_forward_sharded_inputs
        _ = self.model(tokens, segment_ids, cache, attn_mask=None)
      File ".../models/llama3_2/modeling.py", line 620, in __call__
        x = layer(x, segment_ids, attn_mask=attn_mask, cache=layer_cache)
      ...
      File ".../models/llama3_2/modeling.py", line 504, in __call__
        cache.k_cache[...] = jax.lax.dynamic_update_slice(cache.k_cache[...], k, slice_indices)
    jax._src.core.ShardingTypeError: dynamic_update_slice operand sharding must be equal to update sharding, got operand sharding float32[2@fsdp,8,1@tp,4]({Explicit: ('fsdp', 'tp')}) and update sharding float32[2@fsdp,4,1,4@tp]({Explicit: ('fsdp', 'tp')}).
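The mechanics of the cache write can be seen even on a single device: `jax.lax.dynamic_update_slice` writes an update block into an operand at given start indices, and under explicit sharding the two must share a layout, hence the reshard before the write. A minimal single-device sketch (shapes illustrative, sharding elided):

```python
import jax
import jax.numpy as jnp

# (batch, seq, kv_heads, head_dim) cache, mirroring the traceback's shapes.
k_cache = jnp.zeros((2, 8, 1, 4))
k_new = jnp.ones((2, 4, 1, 4))      # keys for 4 new positions
slice_indices = (0, 2, 0, 0)        # write starting at sequence position 2
k_cache = jax.lax.dynamic_update_slice(k_cache, k_new, slice_indices)
print(k_cache[0, :, 0, 0])          # positions 2..5 now hold the new keys
```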
Resolves #115

Reference

Checklist

- `run_model.py` for model usage, `test_outputs.py` and/or `model_validation_colab.ipynb` for quality).