Support grain data checkpoint for elastic training#3673
Conversation
Force-pushed from 7bd4359 to 947c587
Force-pushed from e8e9f36 to 5febf80
🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

🤖 I'm sorry @aireenmei, but I was unable to process your request. Please see the logs for more details.
How are you confirming that the dataset is resumed from the correct index?
return item
...
# ElasticIterator: every process reads the same shared `process_0.json`.
if isinstance(item, ElasticIterator):
We want ElasticIterator to also be a RemoteIterator. Is that happening?
We want to support ElasticIterator in both regular (Pathways, mcJAX) and colocated-python environments. Because ElasticIterator allows flexible chip counts, it is useful for regular customers who may change chip counts while keeping training progress. This specific path is for non-colocated cases (including Pathways and mcJAX); the Pathways + colocated case is handled by RemoteIteratorWrapper in the lines above. Let me improve the comments.
process_index = jax.process_index() + i * jax.process_count()
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
if isinstance(data_iterator[0], RemoteIteratorWrapper):
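A minimal, hypothetical sketch of the isinstance dispatch these snippets describe. The class names mirror the PR, but the bodies are stand-in stubs and the routing labels are illustrative only, not real MaxText code:

```python
# Stubs standing in for the real iterator classes in the PR.
class RemoteIteratorWrapper:
    """Stub: iterator hosted in a colocated-python worker (Pathways)."""

class ElasticIterator:
    """Stub: elastic iterator whose state is shared via process_0.json."""

class LocalIterator:
    """Stub: plain per-process Grain iterator."""

def classify(item):
    # Check the colocated (remote) path first, then the shared elastic
    # path, then fall back to the ordinary per-process checkpoint path.
    if isinstance(item, RemoteIteratorWrapper):
        return "remote"          # saved by the colocated-python machinery
    if isinstance(item, ElasticIterator):
        return "elastic-shared"  # every process reads/writes process_0.json
    return "per-process"         # one checkpoint file per process

print(classify(ElasticIterator()))        # -> elastic-shared
print(classify(RemoteIteratorWrapper()))  # -> remote
print(classify(LocalIterator()))          # -> per-process
```

If ElasticIterator were also made a RemoteIterator, as the review asks, the order of the checks above would decide which save path wins.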
Can you pull out what this actually is just for easier readability?
I refactored and removed the "[0]". Hopefully better readability now.
def save_state(self, step):
  step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32)
  step_array = jax.device_put(step_array, self.cpu_sharding)
Do we need any specialization?
Could you explain specialization?
Good question. I just added an "Index verification" section to the PR description.
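The save_state snippet discussed earlier broadcasts the scalar training step into a fixed-shape array before placing it for checkpointing. A dependency-light sketch of the same idea, with numpy standing in for jnp and the jax.device_put + cpu_sharding placement only noted in a comment; the class name and shape here are hypothetical:

```python
import numpy as np

class StepState:
    def __init__(self, shape=(4,)):
        # dummy_array fixes the shape/dtype the step value is broadcast into.
        self.dummy_array = np.zeros(shape, dtype=np.int32)

    def save_state(self, step):
        # Mirrors jnp.full(self.dummy_array.shape, step, dtype=jnp.int32);
        # the PR then calls jax.device_put(step_array, self.cpu_sharding)
        # to pin the array to host memory before saving.
        return np.full(self.dummy_array.shape, step, dtype=np.int32)

print(StepState().save_state(7).tolist())  # -> [7, 7, 7, 7]
```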
Force-pushed from 5febf80 to c8d0b77
Force-pushed from c8d0b77 to 70e42c0
Description
(1) Only ArrayRecord files are supported; Parquet and TFRecord are not supported.
(2) Many-to-one transformations, such as packing and filtering, are not supported.
(3) Mixing datasets is not supported.
Tests
Tested on Pathways: saving and restoring data iterator checkpoints with different numbers of v5e-32 slices.
jobset.yaml
With colocated_python:
colocated_python_data_input=true colocated_python_checkpointing=true grain_use_elastic_iterator=true, checkpoints in gs://aireenmei-multipod/pathways-v5e/20260424_grain_elastic_2/aireen-pathways-v5e/checkpoints

Without colocated_python:
colocated_python_data_input=false colocated_python_checkpointing=false grain_use_elastic_iterator=true, checkpoints in gs://aireenmei-multipod/pathways-v5e/20260424_grain_elastic_3/aireen-pathways-v5e/checkpoints

Index verification
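The check performed below: the delta between two checkpoints' global_next_index values, divided by the global batch size, should equal the number of training steps taken between those checkpoints. A minimal sketch of that arithmetic (the function name is hypothetical, and treating 32 and 64 as the global batch sizes before and after resizing is an assumption based on the divisors used below):

```python
def steps_between(idx_prev, idx_next, global_batch_size):
    """Steps implied by the advance of global_next_index between checkpoints."""
    delta = idx_next - idx_prev
    # The index should advance by a whole number of global batches.
    assert delta % global_batch_size == 0, "index delta not batch-aligned"
    return delta // global_batch_size

print(steps_between(672, 800, 32))   # -> 4
print(steps_between(800, 1184, 64))  # -> 6
```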
Inspecting the checkpoints in gs://aireenmei-multipod/pathways-v5e/20260424_grain_elastic_2/aireen-pathways-v5e/checkpoints/{step}/iter/process_0.json:
"global_next_index": 672
"global_next_index": 800
"global_next_index": 1184
(800 - 672) / 32 = 4 steps
(1184 - 800) / 64 = 6 steps

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.