Skip to content

Commit c340436

Browse files
author
maxtext authors
committed
Merge pull request #1654 from AI-Hypercomputer:lukebaumann/pinned_pathwaysutils
PiperOrigin-RevId: 754147076
2 parents a473487 + cfdc3a9 commit c340436

4 files changed

Lines changed: 11 additions & 11 deletions

File tree

MaxText/elastic_train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def elastic_handler(
118118
with mesh:
119119
data_iterator, _ = create_data_iterator(config, mesh)
120120

121-
step, snapshot = elastic_manager.get_resharded_snapshot(mesh)
121+
step, snapshot_jax_arrays, _ = elastic_manager.get_resharded_snapshot(mesh)
122122

123123
# We do not want to restore from the previous checkpoint but instead
124124
# restore from the host offloaded snapshot.
@@ -143,7 +143,7 @@ def elastic_handler(
143143
checkpoint_manager=None,
144144
)
145145

146-
state = state.replace(**snapshot)
146+
state = state.replace(**snapshot_jax_arrays)
147147
state = state.replace(step=state.step.at[None].set(step))
148148

149149
(
@@ -259,7 +259,7 @@ def train_loop(config, elastic_manager, state=None):
259259

260260
elastic_manager.maybe_snapshot(
261261
step,
262-
snapshot={
262+
snapshot_jax_arrays={
263263
"params": state.params,
264264
"opt_state": state.opt_state,
265265
},
@@ -314,7 +314,7 @@ def train_loop(config, elastic_manager, state=None):
314314

315315
elastic_manager.maybe_snapshot(
316316
step=step,
317-
snapshot={
317+
snapshot_jax_arrays={
318318
"params": state.params,
319319
"opt_state": state.opt_state,
320320
},
@@ -323,7 +323,7 @@ def train_loop(config, elastic_manager, state=None):
323323

324324
ret = elastic_manager.maybe_reshard_up(
325325
step=step,
326-
snapshot={
326+
snapshot_jax_arrays={
327327
"params": state.params,
328328
"opt_state": state.opt_state,
329329
},

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ transformers
3838
mlperf-logging@git+https://github.com/mlperf/logging.git
3939
google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
4040
jsonlines
41-
pathwaysutils@git+https://github.com/AI-Hypercomputer/pathways-utils.git
42-
omegaconf
41+
pathwaysutils==0.1.1
42+
omegaconf

requirements_with_jax_stable_stack.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ transformers
1818
mlperf-logging@git+https://github.com/mlperf/logging.git
1919
google-jetstream@git+https://github.com/AI-Hypercomputer/JetStream.git
2020
jsonlines
21-
pathwaysutils@git+https://github.com/AI-Hypercomputer/pathways-utils.git
21+
pathwaysutils==0.1.1
2222
google-api-python-client
2323
omegaconf
24-
jaxtyping
24+
jaxtyping

requirements_with_jax_stable_stack_0_5_2_pipreqs.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ omegaconf==2.3.0
2727
optax==0.2.4
2828
orbax==0.1.9
2929
pandas==2.2.3
30-
pathwaysutils==0.1.0
30+
pathwaysutils==0.1.1
3131
# Removing due to conflicts during build
3232
# protobuf==3.20.3
3333
protobuf
@@ -53,4 +53,4 @@ tqdm==4.67.1
5353
transformer_engine==2.1.0
5454
transformers==4.51.3
5555
trl==0.16.1
56-
urllib3==2.4.0
56+
urllib3==2.4.0

0 commit comments

Comments
 (0)