Multi-tier checkpointing + orbax replicator#1332
Conversation
| FLAGS = flags.FLAGS | ||
|
|
||
| flags.DEFINE_integer( | ||
| "assume_data_parallelism", |
There was a problem hiding this comment.
future follow up. I think orbax has a way of figuring this out automatically since it also needs to know this info. Orbax requires you to specify the batch dimension afair so it can know this.
There was a problem hiding this comment.
MaxText sets it to the number of slices. However it may not be correct if there is intra-slice DDP, so we plan to make it configurable.
| FLAGS = flags.FLAGS | ||
|
|
||
| flags.DEFINE_integer( | ||
| "assume_data_parallelism", |
There was a problem hiding this comment.
MaxText sets it to the number of slices. However it may not be correct if there is intra-slice DDP, so we plan to make it configurable.
|
This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the |
|
This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue. |
Integrate multi-tier checkpointer + orbax replicator into axlearn