Conversation
Summary of Changes
Hello @xrhd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request integrates a new Graph Attention Network (GAT) model into bonsai.
Activity
Code Review
This pull request introduces a JAX/Flax implementation of the Graph Attention Network (GAT) model, complete with a validation notebook for the Cora dataset and basic unit tests. The implementation correctly supports multi-head attention and aligns with the original GAT architecture. The changes include adding the GAT model files, updating the main README.md to list the new model, and providing a .gitignore for test data. Overall, the code is well-structured and functional, but there are a few areas for improvement regarding documentation consistency, parameter usage, and robustness in error handling.
```python
dropout_rng: jax.Array,
dropout_prob: float = 0.6,
alpha: float = 0.2,
concat_hidden: bool = True,
```
The concat_hidden parameter is defined in the GAT class's __init__ method but is not used. The concat argument for GATLayer within the hidden layers loop is hardcoded to True. This makes the concat_hidden parameter redundant and potentially misleading. Please either remove it or integrate its functionality.
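One way to wire the flag through (a sketch only; the loop structure and the surrounding field names are assumed from the diff, not taken from the actual implementation):

```python
# Hypothetical excerpt of the hidden-layer construction in GAT.
# Instead of hardcoding concat=True, forward the constructor flag:
hidden_layers = []
for _ in range(num_hidden_layers):
    hidden_layers.append(
        GATLayer(
            out_features=hidden_dim,
            num_heads=num_heads,
            alpha=alpha,
            concat=concat_hidden,  # was hardcoded to True
        )
    )
```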
```markdown
## Validation

To reproduce the results on the Cora dataset (~83% accuracy):
```
```python
# Masked attention
# adj is assumed to be 0 for no edge, 1 for edge (including self-loop)
# We want to mask where adj is 0
zero_vec = -9e15 * jnp.ones_like(e)
```
Using a hardcoded large negative number (-9e15) for attention masking can sometimes lead to numerical instability or issues if the floating-point precision changes. It's generally safer and more robust to use jnp.finfo(e.dtype).min to get the smallest representable number for the given data type.
Suggested change:

```diff
-zero_vec = -9e15 * jnp.ones_like(e)
+zero_vec = jnp.finfo(e.dtype).min * jnp.ones_like(e)
```
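A small self-contained check of the dtype-aware mask (the array values below are illustrative, not from the PR): with float32 logits, `jnp.finfo(e.dtype).min` drives non-edge weights to zero after the softmax, just like the hardcoded constant, but it adapts automatically if the model runs in float16 or float64.

```python
import jax
import jax.numpy as jnp

# Illustrative attention logits and adjacency (made-up 2-node graph)
e = jnp.array([[0.5, -1.2], [2.0, 0.3]], dtype=jnp.float32)
adj = jnp.array([[1.0, 0.0], [1.0, 1.0]])

# Dtype-aware mask instead of a hardcoded -9e15
zero_vec = jnp.finfo(e.dtype).min * jnp.ones_like(e)
attention = jnp.where(adj > 0, e, zero_vec)

# Non-edges get (near-)zero weight; each row still sums to 1
weights = jax.nn.softmax(attention, axis=-1)
print(weights)
```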
| " from bonsai.models.gat.params import GATConfig\n", | ||
| "except ImportError:\n", | ||
| " try:\n", | ||
| " !pip insetall -e .\n", |
| "print(f\"Test Accuracy: {test_acc:.4f}\")\n", | ||
| "\n", | ||
| "if test_acc >= 0.80:\n", | ||
| " print(\"SUCCESS: Accuracy is above 80%\")\n", |
| print("Case 2: Feature dimension mismatch...") | ||
| x_wrong = jax.random.normal(key, (N, F + 1)) | ||
| try: | ||
| model(x_wrong, adj_zero, training=False) | ||
| print("Failure: Model should have raised a dimension mismatch error.") | ||
| except Exception as e: | ||
| print(f"Success: Correctly caught error: {e}") |
While JAX will raise an error for dimension mismatches, it's good practice to add explicit input shape validation at the beginning of the __call__ method in GATLayer or GAT. This provides clearer error messages to users and makes the model more robust. For example, you could check h.shape[1] == self.in_features and adj.shape == (N, N).
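Such a check could look like the following (a hypothetical standalone helper, not the PR's code; the real `GATLayer.__call__` would invoke it before computing attention scores). It only inspects `.shape` and `.ndim`, so it works on any array-like input:

```python
def validate_gat_inputs(h, adj, in_features):
    """Explicit shape validation for clearer error messages.

    Hypothetical helper: `in_features` stands in for the layer's
    configured input width (e.g. self.in_features in GATLayer).
    """
    # Node features must be a 2-D (N, in_features) array
    if h.ndim != 2 or h.shape[1] != in_features:
        raise ValueError(
            f"Expected node features of shape (N, {in_features}), "
            f"got {tuple(h.shape)}"
        )
    # Adjacency must be square and match the node count
    n = h.shape[0]
    if tuple(adj.shape) != (n, n):
        raise ValueError(
            f"Expected adjacency of shape ({n}, {n}), got {tuple(adj.shape)}"
        )
```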
Hi @xrhd. Thanks for opening this PR! I see that this is still in a draft stage. Let me know when you're ready for us to take a look. Please also address the comments from gemini-code-assist.
Resolves #12
This PR introduces the Graph Attention Network (GAT) implementation in bonsai. The implementation follows the original architecture by Veličković et al. (2018) and has been validated against the Cora dataset.
Reference
Checklist
- Followed the linting and type-checking guide (https://github.com/jax-ml/bonsai/blob/main/CONTRIBUTING.md#linting-and-type-checking) to format this commit.