@ehorning ehorning commented Aug 12, 2025

Integrate multi-tier checkpointer + orbax replicator into axlearn

@ehorning ehorning marked this pull request as ready for review September 16, 2025 19:39
@ehorning ehorning requested review from a team as code owners September 16, 2025 19:39
FLAGS = flags.FLAGS

flags.DEFINE_integer(
"assume_data_parallelism",

Future follow-up: I think Orbax can figure this out automatically, since it also needs this information. As far as I recall, Orbax requires you to specify the batch dimension, so it should be able to derive this.

MaxText sets it to the number of slices. However, that may not be correct if there is intra-slice DDP, so we plan to make it configurable.
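
To make the concern concrete, here is a minimal sketch of how a configurable value could fall back to the slice count. Everything in it is an assumption for illustration: `resolve_data_parallelism`, `intra_slice_data_parallelism`, and `override` are hypothetical names that do not appear in this PR, and the product rule assumes replicas are counted as slices times intra-slice data-parallel groups.

```python
from typing import Optional


def resolve_data_parallelism(
    num_slices: int,
    intra_slice_data_parallelism: int = 1,
    override: Optional[int] = None,
) -> int:
    """Hypothetical helper: pick the data-parallel replica count.

    Assumption (not from the PR): with no override, the replica count is
    num_slices * intra_slice_data_parallelism. With the default of 1 this
    reduces to the "number of slices" behavior described above.
    """
    if num_slices < 1 or intra_slice_data_parallelism < 1:
        raise ValueError("num_slices and intra_slice_data_parallelism must be >= 1")
    if override is not None:
        # An explicit setting (e.g. from a flag) wins over the derived value.
        if override < 1:
            raise ValueError("override must be >= 1")
        return override
    return num_slices * intra_slice_data_parallelism
```

With the defaults, four slices give a value of 4; adding two data-parallel groups per slice gives 8, which is the case where using the slice count alone would be off.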

@github-actions

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

@github-actions github-actions bot added the stale label Dec 28, 2025
3 participants