@@ -118,7 +118,7 @@ def elastic_handler(
118118 with mesh :
119119 data_iterator , _ = create_data_iterator (config , mesh )
120120
121- step , snapshot = elastic_manager .get_resharded_snapshot (mesh )
121+ step , snapshot_jax_arrays , _ = elastic_manager .get_resharded_snapshot (mesh )
122122
123123 # We do not want to restore from the previous checkpoint but instead
124124 # restore from the host offloaded snapshot.
@@ -143,7 +143,7 @@ def elastic_handler(
143143 checkpoint_manager = None ,
144144 )
145145
146- state = state .replace (** snapshot )
146+ state = state .replace (** snapshot_jax_arrays )
147147 state = state .replace (step = state .step .at [None ].set (step ))
148148
149149 (
@@ -259,7 +259,7 @@ def train_loop(config, elastic_manager, state=None):
259259
260260 elastic_manager .maybe_snapshot (
261261 step ,
262- snapshot = {
262+ snapshot_jax_arrays = {
263263 "params" : state .params ,
264264 "opt_state" : state .opt_state ,
265265 },
@@ -314,7 +314,7 @@ def train_loop(config, elastic_manager, state=None):
314314
315315 elastic_manager .maybe_snapshot (
316316 step = step ,
317- snapshot = {
317+ snapshot_jax_arrays = {
318318 "params" : state .params ,
319319 "opt_state" : state .opt_state ,
320320 },
@@ -323,7 +323,7 @@ def train_loop(config, elastic_manager, state=None):
323323
324324 ret = elastic_manager .maybe_reshard_up (
325325 step = step ,
326- snapshot = {
326+ snapshot_jax_arrays = {
327327 "params" : state .params ,
328328 "opt_state" : state .opt_state ,
329329 },
0 commit comments