Skip to content

Multi-tier checkpointing + orbax replicator#1332

Closed
ehorning wants to merge 29 commits into
apple:mainfrom
ehorning:orbax-mtc-testing
Closed

Multi-tier checkpointing + orbax replicator#1332
ehorning wants to merge 29 commits into
apple:mainfrom
ehorning:orbax-mtc-testing

Conversation

@ehorning
Copy link
Copy Markdown
Contributor

@ehorning ehorning commented Aug 12, 2025

Integrate multi-tier checkpointer + orbax replicator into axlearn

Comment thread Dockerfile Outdated
Comment thread axlearn/cloud/gcp/jobset_utils.py Outdated
Comment thread axlearn/common/checkpointer_orbax.py Outdated
Comment thread axlearn/common/checkpointer_orbax.py Outdated
Comment thread axlearn/experiments/text/gpt/common.py Outdated
Comment thread axlearn/experiments/text/gpt/fuji.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py
Comment thread axlearn/cloud/gcp/jobset_utils.py
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/cloud/gcp/jobset_utils.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py Outdated
Comment thread axlearn/common/checkpointer_orbax_emergency_replicator.py
Comment thread axlearn/experiments/text/gpt/common.py Outdated
@ehorning ehorning marked this pull request as ready for review September 16, 2025 19:39
@ehorning ehorning requested review from a team as code owners September 16, 2025 19:39
FLAGS = flags.FLAGS

flags.DEFINE_integer(
"assume_data_parallelism",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

future follow up. I think orbax has a way of figuring this out automatically since it also needs to know this info. Orbax requires you to specify the batch dimension afair so it can know this.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaxText sets it to the number of slices. However it may not be correct if there is intra-slice DDP, so we plan to make it configurable.

FLAGS = flags.FLAGS

flags.DEFINE_integer(
"assume_data_parallelism",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaxText sets it to the number of slices. However it may not be correct if there is intra-slice DDP, so we plan to make it configurable.

@github-actions
Copy link
Copy Markdown

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

@github-actions github-actions Bot added the stale label Dec 28, 2025
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Jan 4, 2026

This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue.

@github-actions github-actions Bot closed this Jan 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants