
Add Llama3.2 (1B & 3B) #138

Open
Moriyuki-S wants to merge 22 commits into jax-ml:main from Moriyuki-S:feat/implement-llama

Conversation

Contributor

@Moriyuki-S Moriyuki-S commented Jan 22, 2026

Resolves #115

Reference

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@gemini-code-assist
Contributor

Summary of Changes

Hello @Moriyuki-S, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the Llama 3.2 large language model family into the bonsai framework, providing a native JAX implementation built with Flax NNX. The changes encompass the complete model definition, a sophisticated mechanism for loading pre-trained weights from Hugging Face, and a suite of tests to ensure functional correctness and proper handling of advanced features like sharding and padding. This addition significantly expands the model capabilities of the bonsai library, offering state-of-the-art LLMs for JAX users.

Highlights

  • New Model Integration: Introduced the Llama 3.2 language model family (1B and 3B, base and instruct variants) into the bonsai library, providing a pure JAX implementation using Flax NNX.
  • Comprehensive Model Architecture: Implemented the full Llama 3.2 architecture, including custom sharding configurations (FSDP, TP), Rotary Position Embeddings (RoPE) with scaling, RMS Normalization, and specialized sharded linear and embedding layers.
  • Hugging Face Weight Loading: Developed a robust mechanism for loading pre-trained Llama 3.2 model weights from Hugging Face safetensors files, including detailed key mapping and necessary tensor transformations.
  • Extensive Testing Suite: Added a comprehensive set of unit tests covering output parity against the Hugging Face PyTorch reference, correct handling of padding and segment IDs, and validation of sharding configurations.
  • Runnable Examples: Provided runnable scripts for both the base and instruct variants of Llama 3.2, demonstrating how to download weights, initialize the model, and perform text generation.
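Since the highlights mention RoPE, here is a minimal, hedged sketch of rotary position embeddings in plain JAX. The rotate-half layout, `theta` default, and shapes are illustrative assumptions, not the PR's actual code:

```python
import jax.numpy as jnp

def rope(x, positions, theta=10000.0):
    # x: (batch, seq, num_heads, head_dim); positions: (batch, seq)
    head_dim = x.shape[-1]
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = positions[..., None] * freqs   # (batch, seq, head_dim // 2)
    sin = jnp.sin(angles)[:, :, None, :]    # broadcast over heads
    cos = jnp.cos(angles)[:, :, None, :]
    x1, x2 = jnp.split(x, 2, axis=-1)       # rotate-half layout (illustrative)
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

x = jnp.ones((1, 3, 2, 8))
positions = jnp.arange(3)[None, :]
print(rope(x, positions).shape)  # (1, 3, 2, 8)
```

At position 0 the rotation is the identity, which is a quick sanity check for any RoPE implementation.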




Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds a comprehensive implementation of the Llama 3.2 model, including model definition, weight loading from Hugging Face, example run scripts, and a suite of tests for correctness, padding, and sharding. The code is well-structured and follows good practices for writing JAX/Flax models.

I've identified a couple of minor issues: one regarding some dead code in the parameter loading script and another concerning a loose tolerance in one of the output parity tests. Addressing these will improve the code's clarity and robustness. Overall, this is a great contribution.

Comment on lines +165 to +166
if cfg.tie_word_embeddings and "lm_head" in state_dict:
state_dict["lm_head"]["kernel"] = state_dict["embedder"]["embedding"].T
Contributor


medium

This block of code appears to be unreachable and can be removed for clarity.

  • If cfg.tie_word_embeddings is True, self.lm_head is None in the Llama model, so "lm_head" won't be in state_dict.
  • If cfg.tie_word_embeddings is False, the condition cfg.tie_word_embeddings is false.

In either scenario, this code block is dead. The model correctly handles tied embeddings through embedder.decode().
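The tied-embedding behavior described above (no separate lm_head; the embedder's decode path reuses the embedding matrix) can be illustrated with a minimal, hypothetical sketch in plain JAX; `encode`/`decode` and the shapes are invented for illustration:

```python
import jax
import jax.numpy as jnp

# A single (vocab, dim) matrix serves both token lookup and the
# output projection — the essence of tie_word_embeddings.
vocab, dim = 64, 16
embedding = jax.random.normal(jax.random.key(0), (vocab, dim))

def encode(tokens):
    return embedding[tokens]      # token lookup -> (..., dim)

def decode(hidden):
    return hidden @ embedding.T   # output projection -> (..., vocab)

tokens = jnp.array([[1, 2, 3]])
logits = decode(encode(tokens))
print(logits.shape)  # (1, 3, 64)
```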

t_logits = self.torch_model(**t_inputs).logits

n_logits = self.llama_model(n_tokens, segment_ids, cache=None, attn_mask=None)
np.testing.assert_allclose(n_logits, t_logits.detach().cpu().numpy(), rtol=5e-2, atol=5e-2)
Contributor


medium

The tolerance for this assert_allclose is very high (5e-2), which might hide subtle correctness issues between your implementation and the reference PyTorch model. While floating-point differences are expected across frameworks, such a large tolerance is concerning.

Please investigate if a tighter tolerance can be achieved. For reference, test_forward_logits uses a much tighter tolerance of 1e-4. It would be ideal to have similar precision for the full logit comparison to ensure correctness.

Suggested change
np.testing.assert_allclose(n_logits, t_logits.detach().cpu().numpy(), rtol=5e-2, atol=5e-2)
np.testing.assert_allclose(n_logits, t_logits.detach().cpu().numpy(), rtol=1e-4, atol=1e-4)

@Moriyuki-S Moriyuki-S changed the title from "Add Llama3.2" to "Add Llama3.2 (1B & 3B)" on Jan 22, 2026
@Moriyuki-S
Contributor Author

Sharding config is implemented, but testing on multi-device clusters is pending.

@chapman20j
Collaborator

If you want to debug/test small models locally, you can just set up the mesh to be (1, 1) or update the number of cpu devices using jax.config.update("jax_num_cpu_devices", 8). Are you running some benchmarks for performance testing?


cls.batch_size = 4
cls.num_input_tokens = 6
cls.relaxed_tol = 1e-3
Collaborator


Can we use a different tolerance for each test? 1e-3 is relatively large for individual layers and may mask real numerical differences.

Member


+1, since we are computing this in the highest precision, we want tighter constraints (e.g. 1e-5). Also, let's not have relaxed_tol as a shared class variable; maybe make it explicit in each test, as different tests may require different constraints?

ex:

assert_close(x, y, rtol=1e-5, atol=1e-3)

Contributor Author


I’ve set the tolerance to 1e-5 for most tests.

For the specific functions below, I measured the actual maximum error and set the tolerance to roughly twice that value as a safety margin.

[ RUN      ] TestOutputsLlama32.test_forward_logits
test_forward_logits max_abs 4.5776367e-05 max_rel 0.29375324
[       OK ] TestOutputsLlama32.test_forward_logits
[ RUN      ] TestOutputsLlama32.test_full_logits
test_full_logits max_abs 0.00027751923 max_rel 7.5256386
Test Code for Verification
    def test_full_logits(self):
        t_inputs = self._make_torch_input()
        n_tokens = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
        attention_mask = jnp.array(t_inputs["attention_mask"].detach().cpu().numpy())
        segment_ids = attention_mask.astype(jnp.int32)

        with torch.no_grad():
            t_logits = self.torch_model(**t_inputs).logits

        n_logits = self.llama_model(n_tokens, segment_ids, cache=None, attn_mask=None)
        t_logits_np = t_logits.detach().cpu().numpy()
        n_logits_np = np.array(n_logits)
        diff = np.abs(n_logits_np - t_logits_np)
        rel = diff / np.maximum(np.abs(t_logits_np), 1e-8)
        print("test_full_logits max_abs", diff.max(), "max_rel", rel.max())
        np.testing.assert_allclose(n_logits_np, t_logits_np, rtol=1e-5, atol=1e-5)

    def test_forward_logits(self):
        t_inputs = self._make_torch_input()
        n_tokens = jnp.array(t_inputs["input_ids"].detach().cpu().numpy())
        attention_mask = jnp.array(t_inputs["attention_mask"].detach().cpu().numpy())
        batch_size, token_len = n_tokens.shape
        cache = self.llama_model.init_cache(self.llama_config, batch_size, token_len, generate_steps=1)

        with torch.no_grad():
            t_logits = self.torch_model(**t_inputs).logits

        n_logits, _ = modeling.forward(self.llama_model, cache, n_tokens, self.pad_id, attention_mask=attention_mask)
        t_logits_np = t_logits[:, -1].detach().cpu().numpy()
        n_logits_np = np.array(n_logits)
        diff = np.abs(n_logits_np - t_logits_np)
        rel = diff / np.maximum(np.abs(t_logits_np), 1e-8)
        print("test_forward_logits max_abs", diff.max(), "max_rel", rel.max())
        np.testing.assert_allclose(n_logits_np, t_logits_np, rtol=1e-4, atol=1e-4)

Comment on lines +93 to +97
torch.tensor(np.array(jy, dtype=np.float32)),
ty,
rtol=self.relaxed_tol,
atol=self.relaxed_tol,
check_dtype=False,
Collaborator


Could we enable dtype checking?

atol=self.relaxed_tol,
check_dtype=False,
)

Collaborator


Could you add a test here for the attention and/or decoder. This helps identify potential errors in the masking, etc.

from jax.sharding import AxisType

from bonsai.models.llama32 import modeling

Collaborator


Can you update the tests in this file to use absl for consistency with other models in the repo?

Comment on lines +11 to +27
def _tiny_sharded_config() -> modeling.ModelConfig:
return modeling.ModelConfig(
vocab_size=64,
hidden_size=32,
intermediate_size=64,
num_hidden_layers=2,
num_attention_heads=4,
head_dim=8,
num_key_value_heads=2,
max_position_embeddings=128,
rms_norm_eps=1e-5,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=True,
shd_cfg=modeling.LlamaShardCfg.default(use_fsdp=True, use_tp=True),
dtype=jnp.float32,
)
Collaborator


Could combine this with the _tiny_config and accept an argument for whether to use sharding.
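A hypothetical sketch of that suggestion (the names and fields are invented for illustration; the real factory would return modeling.ModelConfig and toggle modeling.LlamaShardCfg):

```python
from dataclasses import dataclass

# Stand-in for modeling.ModelConfig, trimmed to a couple of fields.
@dataclass(frozen=True)
class TinyCfg:
    vocab_size: int = 64
    hidden_size: int = 32
    use_fsdp: bool = False
    use_tp: bool = False

def tiny_config(sharded: bool = False) -> TinyCfg:
    # One factory serves both the plain and the sharded tests.
    return TinyCfg(use_fsdp=sharded, use_tp=sharded)

print(tiny_config(sharded=True))
```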

@chapman20j
Collaborator

If you want to debug/test small models locally, you can just set up the mesh to be (1, 1) or update the number of cpu devices using jax.config.update("jax_num_cpu_devices", 8). Are you running some benchmarks for performance testing?

Oh. I see that you're doing the first solution in your tests. Are you doing some kind of performance testing?

Comment on lines +331 to +351
def compute_positions_from_segment_ids(seg_ids: Array) -> Array:
"""Compute position ids from segment ids."""
seg_ids = seg_ids.astype(jnp.int32)
pad_sentinel = 2**30

def step(carry: tuple[Array, Array], seg_id: Array) -> tuple[tuple[Array, Array], Array]:
prev_seg, prev_pos = carry
is_pad = seg_id == 0
is_new = seg_id != prev_seg
zero = jnp.zeros_like(seg_id)
pos = jnp.where(is_pad, zero, jnp.where(is_new, zero, prev_pos + 1))
pad_val = jnp.full_like(seg_id, pad_sentinel)
out = jnp.where(is_pad, pad_val, pos)
new_prev_seg = jnp.where(is_pad, zero, seg_id)
new_prev_pos = jnp.where(is_pad, zero, pos)
return (new_prev_seg, new_prev_pos), out

base = jnp.zeros_like(seg_ids[:, 0])
init = (base, base)
_, out = jax.lax.scan(step, init, seg_ids.T)
return cast(Array, out.T)
Collaborator


Does the implementation still pass tests when using a simpler compute_positions_from_segment_ids? e.g. the code from gemma3 looks like

def compute_positions_from_segment_ids(seg_ids: Array):
    return jax.vmap(lambda row: jnp.where(row != 0, jnp.arange(seg_ids.shape[1]) - jnp.argmax(row), 2**30))(seg_ids)

Contributor Author


Unfortunately, the simpler implementation fails the tests because the current logic supports packed sequences.

For example, with seg_ids = [1, 1, 2, 2]:

  • Simpler implementation: [0, 1, 2, 3]

  • Current implementation: [0, 1, 0, 1] (matches the test expectation)

I agree that the simpler implementation is cleaner. Do you think we should prioritize the simpler implementation, or maintain the current logic to support packing?
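The two behaviors above can be reproduced with a small, self-contained sketch. This is simplified from the quoted code, and `positions_simple` uses `argmax(row != 0)` to locate the first non-pad token:

```python
import jax
import jax.numpy as jnp

def positions_packed(seg_ids):
    """Scan-based positions: restart at 0 at each segment boundary (packing)."""
    seg_ids = seg_ids.astype(jnp.int32)

    def step(carry, seg_id):
        prev_seg, prev_pos = carry
        is_pad = seg_id == 0
        is_new = seg_id != prev_seg
        zero = jnp.zeros_like(seg_id)
        pos = jnp.where(is_pad | is_new, zero, prev_pos + 1)
        out = jnp.where(is_pad, jnp.full_like(seg_id, 2**30), pos)
        return (jnp.where(is_pad, zero, seg_id), jnp.where(is_pad, zero, pos)), out

    base = jnp.zeros_like(seg_ids[:, 0])
    _, out = jax.lax.scan(step, (base, base), seg_ids.T)
    return out.T

def positions_simple(seg_ids):
    """Gemma3-style: one monotone range per row, offset to the first non-pad."""
    n = seg_ids.shape[1]
    return jax.vmap(
        lambda row: jnp.where(row != 0, jnp.arange(n) - jnp.argmax(row != 0), 2**30)
    )(seg_ids)

seg = jnp.array([[1, 1, 2, 2]])
print(positions_packed(seg)[0].tolist())  # [0, 1, 0, 1]
print(positions_simple(seg)[0].tolist())  # [0, 1, 2, 3]
```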

Collaborator


We can go with the current logic then. Thanks for clarifying. Could you mention this in the docstring for the method? Compute position ids from segment ids with support for packed sequences.

@chapman20j
Collaborator

Hi @Moriyuki-S. Thank you for the nice PR. We left a few comments on the PR. Could you please address these? One other comment is that the folder name should probably be llama3_2 for clarity. Looking forward to adding llama3.2 to the repo!

Member

@jenriver jenriver left a comment


Thanks for the great PR! This looks solid and looks good to go after just some minor cleanups we pointed out.

Also, please check out the contribution guidelines and run the pre-commit hooks to ensure proper linting and formatting! :)

@@ -0,0 +1,113 @@
# Copyright 2026 The JAX Authors.
Member


This looks mostly identical to run_model.py; could you merge them?

Contributor Author


Sure! I originally planned to implement the base and instruct models in different files, but I agree that a single file is simpler.

I'll merge them and add argument parsing to handle the switch. What do you think about running it like this?

# Run Instruct model (Default: 1B-Instruct)
python3 run_model.py

# Run 3B Base model
python3 run_model.py --size 3B --base
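A minimal argparse sketch of the proposed interface (flag names taken from the commands above; everything else is illustrative, and `parse_args` is given an explicit argv so the snippet runs standalone):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Run Llama 3.2 (sketch)")
    p.add_argument("--size", choices=["1B", "3B"], default="1B")
    p.add_argument("--base", action="store_true",
                   help="use the base (non-instruct) variant")
    return p

# Equivalent to: python3 run_model.py --size 3B --base
args = build_parser().parse_args(["--size", "3B", "--base"])
print(args.size, args.base)  # 3B True
```

With no flags this defaults to the 1B Instruct model, matching the proposed behavior.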

@Moriyuki-S
Contributor Author

@chapman20j

If you want to debug/test small models locally, you can just set up the mesh to be (1, 1) or update the number of cpu devices using jax.config.update("jax_num_cpu_devices", 8).

Thanks for the tip! I didn't know about that config setting. I'll try debugging with it.
Do you think I should also verify the performance on actual multi-GPU/TPU hardware?


Are you running some benchmarks for performance testing?

Could you let me know what specific metrics you are looking for (e.g., throughput, step time)?

@Moriyuki-S
Contributor Author

@chapman20j @jenriver
Thank you for the great feedback!
I've left a few questions on your comments and will continue working on the other fixes in the meantime.

@chapman20j
Collaborator

Hi @Moriyuki-S. Thanks. Let us know when you want us to take another look at the code.

We plan to do more multi-hardware testing in the future. If you have multiple devices and want to run tests, we'd appreciate having some benchmarking documented in the README. You can see some of this in the earlier commits to this repo (https://github.com/jax-ml/bonsai/tree/387f7005b4da2c0b1ed272b59b053a2e4dfef8f6/bonsai/models/qwen3): it shows supported hardware setups, and we had profiling in the run_model.py file. We opted for simplicity in run_model.py for now; for the time being, we are prioritizing working, efficient implementations that can be further optimized later.

@Moriyuki-S
Contributor Author

Moriyuki-S commented Feb 2, 2026

@chapman20j @jenriver
I've addressed your comments. Could you take another look?

@Moriyuki-S
Contributor Author

@chapman20j
Unfortunately, I don't have a multi-device setup, so I'd like to leave it as is for this PR and hopefully address this in the future!

We plan to do more multi-hardware testing in the future. If you have multiple devices and want to run tests, we'd appreciate having some benchmarking documented in the README. You can see some of this in the earlier commits to this repo (https://github.com/jax-ml/bonsai/tree/387f7005b4da2c0b1ed272b59b053a2e4dfef8f6/bonsai/models/qwen3): it shows supported hardware setups, and we had profiling in the run_model.py file. We opted for simplicity in run_model.py for now; for the time being, we are prioritizing working, efficient implementations that can be further optimized later.

Comment on lines +40 to +61
emb_weight: PartitionSpec | None = None
activation: PartitionSpec | None = None
logits: PartitionSpec | None = None

# Attention
q_proj: PartitionSpec | None = None
k_proj: PartitionSpec | None = None
v_proj: PartitionSpec | None = None
o_proj: PartitionSpec | None = None

attn_logits: PartitionSpec | None = None
attn_out: PartitionSpec | None = None

cache: PartitionSpec | None = None

# MLP
gate_proj: PartitionSpec | None = None
up_proj: PartitionSpec | None = None
down_proj: PartitionSpec | None = None

# Head
lm_head: PartitionSpec | None = None
Contributor Author


I noticed this was updated in Gemma 3, so I aligned this implementation with it. ( #141 )

Comment thread: bonsai/models/llama3_2/modeling.py (Outdated)
Comment on lines +501 to +505
cache_shd = self.config.shd_cfg.cache
k = shard(k, cache_shd)
v = shard(v, cache_shd)
cache.k_cache[...] = jax.lax.dynamic_update_slice(cache.k_cache[...], k, slice_indices)
cache.v_cache[...] = jax.lax.dynamic_update_slice(cache.v_cache[...], v, slice_indices)
Contributor Author

@Moriyuki-S Moriyuki-S Feb 2, 2026


I fixed the sharding error by explicitly resharding k/v to match the cache layout.

Error Details:

...
Traceback (most recent call last):
  File ".../models/llama3_2/tests/test_sharding_llama3_2.py", line 46, in test_forward_sharded_inputs
    _ = self.model(tokens, segment_ids, cache, attn_mask=None)
  File ".../models/llama3_2/modeling.py", line 620, in __call__
    x = layer(x, segment_ids, attn_mask=attn_mask, cache=layer_cache)
...
  File ".../models/llama3_2/modeling.py", line 504, in __call__
    cache.k_cache[...] = jax.lax.dynamic_update_slice(cache.k_cache[...], k, slice_indices)
jax._src.core.ShardingTypeError: dynamic_update_slice operand sharding must be equal to update sharding, got operand sharding float32[2@fsdp,8,1@tp,4]({Explicit: ('fsdp', 'tp')}) and update sharding float32[2@fsdp,4,1,4@tp]({Explicit: ('fsdp', 'tp')}).
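A hedged sketch of the fix pattern described above. The PR uses its own shard helper under explicit sharding; here jax.lax.with_sharding_constraint and a trivial 1x1 mesh stand in so the snippet runs on a single device, and all shapes are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1x1 mesh: enough to show the pattern without real multi-device hardware.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("fsdp", "tp"))
cache_sharding = NamedSharding(mesh, P("fsdp", None, "tp", None))

@jax.jit
def update_cache(k_cache, k_new):
    # Constrain the incoming K block to the cache's sharding before the
    # update, so dynamic_update_slice sees matching operand/update shardings.
    k_new = jax.lax.with_sharding_constraint(k_new, cache_sharding)
    return jax.lax.dynamic_update_slice(k_cache, k_new, (0, 0, 0, 0))

k_cache = jnp.zeros((2, 16, 1, 4))
k_new = jnp.ones((2, 4, 1, 4))
out = update_cache(k_cache, k_new)
print(out.sum())  # 32.0
```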



Development

Successfully merging this pull request may close these issues.

Llama 3.2 (1B & 3B)

3 participants