Skip to content

Support grain data checkpoint for elastic training#3673

Open
aireenmei wants to merge 1 commit intomainfrom
aireen/elastic_data
Open

Support grain data checkpoint for elastic training#3673
aireenmei wants to merge 1 commit intomainfrom
aireen/elastic_data

Conversation

@aireenmei
Copy link
Copy Markdown
Collaborator

@aireenmei aireenmei commented Apr 15, 2026

Description

  • migrate RemoteIterator to colocated python class
  • Add checkpointing logic to RemoteIterator, so data iterator in the colocated sidecar writes checkpoint to the checkpoint path, prevent sending data iterator state to the controller
  • Add grain.ElasticIterator support controlled by flag grain_use_elastic_iterator, can be used with or without elastic training, with or without colocated_python. This class allows recovering checkpoint with a dynamic scale (up or down), with the following limitations, future work will loose the limitations:
    (1). Only arrayrecord files are supported, parquet or tfrecord are not supported
    (2). Does not support many-to-one transformations, including packing, filtering
    (3). Does not support mixing datasets

Tests

Tested on Pathways saving and restoring data iterator checkpoints with different # of v5e-32 slices
jobset.yaml

With colocated_python

colocated_python_data_input=true colocated_python_checkpointing=true grain_use_elastic_iterator=true, checkpoints in gs://aireenmei-multipod/pathways-v5e/20260424_grain_elastic_2/aireen-pathways-v5e/checkpoints:

  • Start with 1 slice, set steps=25, checkpoint at step 0, 10, 20, 24: log (log confirms Num devices: 32),
  • Resume with 2 slices, set steps=45, checkpoint at step 30, 40, 44: log(log confirms Num devices: 64),
  • Resume with 1 slice, set steps=65, checkpoint at step 50, 60, 64: log(log confirms Num devices: 32)

Without colocated_python

colocated_python_data_input=false colocated_python_checkpointing=false grain_use_elastic_iterator=true, checkpoints in gs://aireenmei-multipod/pathways-v5e/20260424_grain_elastic_3/aireen-pathways-v5e/checkpoints:

  • Start with 1 slice, set steps=25, checkpoint at step 0, 10, 20, 24: log (log confirms Num devices: 32),
  • Resume with 2 slices, set steps=45, checkpoint at step 30, 40, 44: log(log confirms Num devices: 64),
  • Resume with 1 slice, set steps=65, checkpoint at step 50, 60, 64: log(log confirms Num devices: 32)

Index verification

Inspecting the checkpoints in gs://aireenmei-multipod/pathways-v5e/20260424_grain_elastic_2/aireen-pathways-v5e/checkpoints/{step}/iter/process_0.json

  • Step 20: "global_next_index": 672
  • Step 24: "global_next_index": 800
  • Step 30: "global_next_index": 1184
  • Step 20-24 is on 1 slice, 32 devices, batch_size=32, matches (800-672) / 32 = 4 steps
  • Step 24-30 is on 2 slices, 64 devices, batch_size=64, matches (1184-800) / 64 = 6 steps

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 15, 2026

@aireenmei aireenmei force-pushed the aireen/elastic_data branch 2 times, most recently from 7bd4359 to 947c587 Compare April 16, 2026 19:35
@aireenmei aireenmei force-pushed the aireen/elastic_data branch 2 times, most recently from e8e9f36 to 5febf80 Compare April 24, 2026 18:55
@github-actions
Copy link
Copy Markdown

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions
Copy link
Copy Markdown

🤖 I'm sorry @aireenmei, but I was unable to process your request. Please see the logs for more details.

@github-actions
Copy link
Copy Markdown

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@aireenmei aireenmei requested a review from igorts-git as a code owner April 24, 2026 19:19
@github-actions
Copy link
Copy Markdown

🤖 I'm sorry @aireenmei, but I was unable to process your request. Please see the logs for more details.

@lukebaumann
Copy link
Copy Markdown
Collaborator

How are you confirming that the dataset is resumed from the correct index?

return item

# ElasticIterator: every process reads the same shared `process_0.json`.
if isinstance(item, ElasticIterator):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want ElasitcIterator to also be a RemoteIterator. Is that happening?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to support ElasitcIterator in both regular (Pathways, mcJAX) and colocated python environments. Because ElasitcIterator allows flexible chip counts, it has use case for regular customers who may change chip counts while keeping training progress. This specific path is for non-colocated cases (including Pathways and mcJAX), the Pathways + colocated case is handled by RemoteIteratorWrapper in the lines above. Let me improve the comments

Comment thread src/maxtext/common/checkpointing.py Outdated
process_index = jax.process_index() + i * jax.process_count()
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
if isinstance(data_iterator[0], RemoteIteratorWrapper):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you pull out what this actually is just for easier readability?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored and removed "[0]". Hopefully better readability now


def save_state(self, step):
step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32)
step_array = jax.device_put(step_array, self.cpu_sharding)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do need any specialization?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain specialization?

@aireenmei
Copy link
Copy Markdown
Collaborator Author

How are you confirming that the dataset is resumed from the correct index?

Good question. I just added a section "Index verification" in PR description.

@aireenmei aireenmei force-pushed the aireen/elastic_data branch from 5febf80 to c8d0b77 Compare April 24, 2026 22:01
@aireenmei aireenmei force-pushed the aireen/elastic_data branch from c8d0b77 to 70e42c0 Compare April 24, 2026 22:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants