
Conversation

@samanklesaria (Collaborator) commented on Nov 28, 2025

What does this PR do?

This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include the following (minimal sketches of each appear after this description):

  • Calculating exponential moving averages (EMA) of parameters
  • Optimizing only a low-rank addition to certain weights (LoRA)
  • Using different learning rates for different parameters to implement the maximal update parameterization (muP)
  • Using second-order optimizers such as L-BFGS
  • Specifying sharding for optimizer state that differs from the sharding of the parameters
  • Accumulating gradients over multiple micro-batches

This is a work in progress: the guide will be fleshed out much further over time.

The guide emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviations from the often more intuitive pure-JAX version.
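
For instance, here is a minimal EMA sketch in that pure-JAX style (the toy linear model, `loss_fn`, and the decay value are all illustrative):

```python
import jax
import jax.numpy as jnp
import optax

# Toy setup (illustrative): a linear model trained with Adam.
def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

tx = optax.adam(1e-3)
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
opt_state = tx.init(params)
ema_params = params  # start the EMA at the initial parameters
decay = 0.999

@jax.jit
def train_step(params, ema_params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Blend the EMA toward the fresh parameters, leaf by leaf.
    ema_params = jax.tree.map(
        lambda e, p: decay * e + (1 - decay) * p, ema_params, params)
    return params, ema_params, opt_state
```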
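
A LoRA sketch in the same spirit: the pretrained weight `W` is closed over and frozen, and gradients flow only through the rank-`r` factors (all shapes and names are hypothetical):

```python
import jax
import jax.numpy as jnp

d, r = 512, 8
W = jax.random.normal(jax.random.key(0), (d, d))  # frozen pretrained weight

# Standard LoRA init: B starts at zero, so W + A @ B == W initially.
lora = {
    "A": 0.01 * jax.random.normal(jax.random.key(1), (d, r)),
    "B": jnp.zeros((r, d)),
}

def apply(lora, x):
    # Only the low-rank factors are trainable; W is closed over and frozen.
    return x @ (W + lora["A"] @ lora["B"])

x = jnp.ones((4, d))
grads = jax.grad(lambda l: jnp.sum(apply(l, x)))(lora)  # grads only for A, B
```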
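
For per-parameter learning rates, `optax.multi_transform` routes each parameter leaf to its own optimizer. The labeling rule below is a stand-in for illustration, not a complete muP recipe:

```python
import jax
import optax

width = 512  # hypothetical hidden width

def label_fn(params):
    # Label each leaf by its path: kernels get the width-scaled rate.
    return jax.tree_util.tree_map_with_path(
        lambda path, _: "hidden" if "kernel" in jax.tree_util.keystr(path)
        else "rest",
        params)

tx = optax.multi_transform(
    {"hidden": optax.adam(1e-3 / width),  # width-scaled learning rate
     "rest": optax.adam(1e-3)},
    label_fn)
```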
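
L-BFGS needs the loss value and a line search at each step, so its `update` takes extra arguments. This sketch follows the pattern documented for recent Optax releases and may need adapting to your version:

```python
import jax.numpy as jnp
import optax

def loss_fn(params):
    return jnp.sum((params - 3.0) ** 2)  # toy objective

params = jnp.zeros((5,))
tx = optax.lbfgs()
opt_state = tx.init(params)
value_and_grad = optax.value_and_grad_from_state(loss_fn)

for _ in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    # L-BFGS consumes the value and gradient, and re-evaluates loss_fn
    # internally during its line search.
    updates, opt_state = tx.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=loss_fn)
    params = optax.apply_updates(params, updates)
```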
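
For optimizer state whose sharding differs from the parameters', one option is to place the state explicitly after initialization. A sketch, assuming the leading dimensions of the state leaves divide the device count (the toy parameters are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
replicated = NamedSharding(mesh, P())
row_sharded = NamedSharding(mesh, P("devices"))  # shard the leading axis

params = {"w": jnp.ones((8, 4))}             # toy params, kept replicated
params = jax.device_put(params, replicated)

tx = optax.adam(1e-3)

def place(leaf):
    # Shard non-scalar state (the Adam moments) along the leading axis;
    # scalar leaves such as the step count stay replicated.
    return jax.device_put(
        leaf, row_sharded if jnp.ndim(leaf) > 0 else replicated)

opt_state = jax.tree.map(place, tx.init(params))
```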
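
Finally, gradient accumulation drops in as a wrapper around any Optax optimizer via `optax.MultiSteps`:

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((3,))}

# Accumulate gradients over 4 micro-batches, then apply one real update.
tx = optax.MultiSteps(optax.adamw(1e-3), every_k_schedule=4)
opt_state = tx.init(params)
# tx.update has the usual GradientTransformation interface; on 3 of every
# 4 calls it returns zero updates while accumulating, and on the 4th it
# returns the update computed from the averaged gradients.
```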

@review-notebook-app commented

Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.


samanklesaria force-pushed the opt_cookbook branch 3 times, most recently from c495dc1 to b929529 on December 1, 2025 at 23:53
samanklesaria force-pushed the opt_cookbook branch 4 times, most recently from d3f39f9 to 34d7c20 on December 9, 2025 at 22:26