
Gpt oss #166

Open
Aatman09 wants to merge 4 commits into jax-ml:main from Aatman09:gpt_oss

Conversation

@Aatman09
Contributor

Resolves #143

Reference

I have implemented the modeling.py and params.py code; test_outputs.py and run_model.py are yet to be committed.

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@gemini-code-assist
Contributor

Summary of Changes

Hello @Aatman09, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new GPT-OSS model architecture within the bonsai framework, implemented using JAX and Flax NNX. It provides the foundational components for the model's layers, attention mechanisms, and a Mixture-of-Experts (MoE) feed-forward network. Additionally, it includes functionality to facilitate the loading of pretrained model weights from safetensors, enabling easy integration and use of pre-existing models.

Highlights

  • New GPT-OSS Model Implementation: Introduced a comprehensive implementation of the GPT-OSS model architecture using JAX and Flax NNX, including its configuration, core layers, and attention mechanisms.
  • Mixture-of-Experts (MoE) Support: Integrated a Mixture-of-Experts (MoE) MLP within the decoder layers, featuring a Top-K Router and Expert modules for efficient conditional computation.
  • Rotary Positional Embeddings (RoPE): Implemented Rotary Positional Embeddings (RoPE) for attention layers, enhancing the model's ability to handle sequence position information.
  • Parameter Loading Utilities: Added a utility module for loading pretrained model weights from safetensors files, including robust key mapping and transformation logic to align with the JAX/NNX model structure.


Changelog
  • bonsai/models/gpt_oss/modeling.py
    • Added the complete GPT-OSS model architecture, including configuration, core components, and the overall model structure.
  • bonsai/models/gpt_oss/params.py
    • Added parameter loading utilities for GPT-OSS models, including key mapping and weight transformation from safetensors.
  • bonsai/models/unet/tests/UNet_segmentation_example.md
    • Removed an emoji from a training log message.
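For context on the params.py changelog entry above, a key-mapping loader of this kind usually renames safetensors tensor names into the JAX/NNX parameter tree and transposes 2-D linear kernels. The sketch below is hypothetical (the real rules live in bonsai/models/gpt_oss/params.py and will differ; in practice the raw dict would come from safetensors' load_file):

```python
import numpy as np

def map_key(hf_key: str) -> str:
    # e.g. "model.layers.0.self_attn.q_proj.weight" -> "layers.0.attn.q_proj.kernel"
    key = hf_key.removeprefix("model.")
    return key.replace("self_attn", "attn").replace(".weight", ".kernel")

def remap(raw: dict) -> dict:
    # HF stores linear weights as [out, in]; Flax Linear expects [in, out],
    # so 2-D kernels are transposed during the remap.
    return {
        map_key(k): (v.T if k.endswith(".weight") and v.ndim == 2 else v)
        for k, v in raw.items()
    }
```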
Activity
  • The author, Aatman09, has implemented the core modeling.py and params.py files for the GPT-OSS model.
  • The author noted that test_outputs.py and run_model.py are yet to be committed, indicating ongoing work.
  • The pull request checklist shows that contribution guidelines were read, pre-commit hooks were used, superfluous comments were removed, and the Contributor License Agreement (CLA) was signed.
  • Unit tests for the changes are currently pending, as indicated by the unchecked item in the checklist.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new GptOss model, which is a GPT-style model with a Mixture-of-Experts (MoE) component. The implementation includes the model definition and a script for loading pretrained weights. My review has identified several significant issues in the model definition (modeling.py) that would prevent it from being trained correctly from scratch. These include incorrect parameter initialization and a missing implementation for the MoE auxiliary loss, which is critical for training such models. I have also pointed out some unused parameters and potentially confusing code. The weight loading script (params.py) appears to be well-implemented. Addressing the feedback on the model definition is crucial for the correctness and usability of this new model.

Comment on lines +76 to +79
self.gate_up_proj = nnx.Param(jnp.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim)))
self.gate_up_proj_bias = nnx.Param(jnp.zeros((self.num_experts, 2 * self.expert_dim)))
self.down_proj = nnx.Param(jnp.zeros((self.num_experts, self.expert_dim, self.hidden_size)))
self.down_proj_bias = nnx.Param(jnp.zeros((self.num_experts, self.hidden_size)))


critical

Parameters are initialized with zeros, which will prevent the model from learning correctly if trained from scratch. These should be initialized using a proper random initializer (e.g., from nnx.initializers).

Additionally, the GptOssExperts module's __init__ method does not accept rngs, which is necessary for random initialization. It should be added to the signature and passed from GptOssMLP.

Comment on lines +210 to +213
hidden_states, _ = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states


critical

The router_scores (the second element of the tuple returned by self.mlp) are discarded here. For training a Mixture-of-Experts model, these scores are crucial for calculating the auxiliary load-balancing loss. They should be returned by this layer and propagated up to the final model output. This will require changing the return signature of this method and updating the call sites in GptOssModel.

Suggested change

-hidden_states, _ = self.mlp(hidden_states)
-hidden_states = residual + hidden_states
-return hidden_states
+hidden_states, router_scores = self.mlp(hidden_states)
+hidden_states = residual + hidden_states
+return hidden_states, router_scores
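Once the router scores are propagated up, the auxiliary loss can be computed from them. A Switch-Transformer-style load-balancing loss is one common choice; the function below is an illustrative sketch, not part of the PR:

```python
import jax
import jax.numpy as jnp

def load_balancing_loss(router_probs, expert_indices, num_experts):
    """router_probs: [tokens, num_experts] softmax outputs of the router.
    expert_indices: [tokens, k] chosen expert ids (top-k)."""
    # Fraction of routed tokens dispatched to each expert.
    one_hot = jax.nn.one_hot(expert_indices, num_experts)  # [tokens, k, E]
    tokens_per_expert = one_hot.sum(axis=(0, 1)) / one_hot.sum()
    # Mean router probability assigned to each expert.
    mean_probs = router_probs.mean(axis=0)
    # Minimized (value 1.0) when both distributions are uniform.
    return num_experts * jnp.sum(tokens_per_expert * mean_probs)
```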

self.down_proj = nnx.Param(jnp.zeros((self.num_experts, self.expert_dim, self.hidden_size)))
self.down_proj_bias = nnx.Param(jnp.zeros((self.num_experts, self.hidden_size)))

def __call__(self, hidden_states, router_indices=None, routing_weights=None):


high

The router_indices parameter is unused within this function. This is confusing and suggests a potential disconnect between the router and the expert layer. If the implementation using the dense routing_weights is correct, please remove the unused router_indices parameter from the signature. If a sparse implementation was intended, this function needs to be updated to use the indices.

Suggested change

-def __call__(self, hidden_states, router_indices=None, routing_weights=None):
+def __call__(self, hidden_states, routing_weights=None):

self.num_heads * self.head_dim, self.hidden_size, use_bias=config.attention_bias, rngs=rngs
)

self.sinks = nnx.Param(jnp.zeros((self.num_heads,)))


medium

The sinks parameter is initialized here but is not used anywhere in the __call__ method. This appears to be dead code and should be removed if it's not being used.
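For reference, attention sinks are typically wired in by appending one learned logit per head to the attention scores before the softmax and then discarding that column, so each row's probabilities sum to less than one. The sketch below shows that usual wiring under assumed shapes; whether the PR intends this is for the author to confirm:

```python
import jax
import jax.numpy as jnp

def attention_with_sinks(scores, sinks):
    """scores: [batch, heads, q_len, kv_len]; sinks: [heads] learned logits."""
    b, h, q, _ = scores.shape
    # Broadcast the per-head sink logit into an extra key column.
    sink_col = jnp.broadcast_to(sinks[None, :, None, None], (b, h, q, 1))
    logits = jnp.concatenate([scores, sink_col], axis=-1)
    probs = jax.nn.softmax(logits, axis=-1)
    # Drop the sink column: some probability mass is absorbed by the sink.
    return probs[..., :-1]
```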

if mask is not None:
scores = scores + mask

probs = nnx.softmax(scores, axis=-1)


medium

The attention_dropout value from the configuration is not used. Applying dropout to the attention probabilities after the softmax operation is a common regularization technique that is missing here. This should be applied to improve model generalization during training.

@coder0143
Contributor

@Aatman09 Good to have GPT-OSS here. Please add tests for layers, attention, and final logits, plus a run_model.py file. Since this is an MoE, do add sharding if possible.
