@@ -182,13 +182,7 @@ def _default_for_sds(sds):
182182 def _make ():
183183 if "key" in str (sds .dtype ):
184184 base = jax .random .key (0 )
185- return (
186- base
187- if sds .shape == ()
188- else jax .random .split (base , int (np .prod (sds .shape ))).reshape (
189- sds .shape
190- )
191- )
185+ return base if sds .shape == () else jax .random .split (base , int (np .prod (sds .shape ))).reshape (sds .shape )
192186 return jnp .zeros (sds .shape , dtype = sds .dtype )
193187
194188 sharding = getattr (sds , "sharding" , None )
@@ -208,9 +202,7 @@ def _populate_pure_dict_from_partial(abstract_pure, partial_concrete):
208202 return {
209203 k : _populate_pure_dict_from_partial (
210204 v ,
211- partial_concrete .get (k )
212- if isinstance (partial_concrete , dict )
213- else None ,
205+ partial_concrete .get (k ) if isinstance (partial_concrete , dict ) else None ,
214206 )
215207 for k , v in abstract_pure .items ()
216208 }
@@ -243,29 +235,21 @@ def _load_linen_checkpoint_into_nnx(
243235 )
244236 )
245237 restore_args = ocp .checkpoint_utils .construct_restore_args (linen_abstract )
246- restored = ocp .args .PyTreeRestore (
247- item = linen_abstract , restore_args = restore_args , partial_restore = True
248- )
238+ restored = ocp .args .PyTreeRestore (item = linen_abstract , restore_args = restore_args , partial_restore = True )
249239 restored = ckptr .restore (epath .Path (path ), args = restored )
250240 partial_nnx = train_state_nnx .from_linen_checkpoint_dict (restored )
251241 return _populate_pure_dict_from_partial (nnx_abstract_pure , partial_nnx )
252242
253243
254244def _rebuild_nnx_with_values (abstract_nnx_state , concrete_weights ):
255245 """Fills each Variable in `abstract_nnx_state` with the matching restored array."""
256- leaves , treedef = jax .tree_util .tree_flatten (
257- abstract_nnx_state , is_leaf = lambda x : isinstance (x , nnx .Variable )
258- )
246+ leaves , treedef = jax .tree_util .tree_flatten (abstract_nnx_state , is_leaf = lambda x : isinstance (x , nnx .Variable ))
259247 concrete = jax .tree_util .tree_leaves (concrete_weights )
260248 if len (leaves ) != len (concrete ):
261249 raise ValueError (
262- f"Params load leaf-count mismatch: { len (leaves )} abstract Variables vs"
263- f" { len (concrete )} restored."
250+ f"Params load leaf-count mismatch: { len (leaves )} abstract Variables vs" f" { len (concrete )} restored."
264251 )
265- new_leaves = [
266- v .replace (value = a ) if isinstance (v , nnx .Variable ) else a
267- for v , a in zip (leaves , concrete )
268- ]
252+ new_leaves = [v .replace (value = a ) if isinstance (v , nnx .Variable ) else a for v , a in zip (leaves , concrete )]
269253 return jax .tree_util .tree_unflatten (treedef , new_leaves )
270254
271255
@@ -284,9 +268,7 @@ def _load_linen_params_into_nnx(
284268 NNX params Variables.
285269 """
286270 max_logging .log (f"Restoring Linen-layout params into NNX state at { path } " )
287- linen_abstract = train_state_nnx .to_linen_checkpoint_dict (
288- {"model" : nnx_params_abstract .to_pure_dict ()}
289- )
271+ linen_abstract = train_state_nnx .to_linen_checkpoint_dict ({"model" : nnx_params_abstract .to_pure_dict ()})
290272 ckptr = ocp .Checkpointer (
291273 ocp .PyTreeCheckpointHandler (
292274 restore_concurrent_gb = checkpoint_storage_concurrent_gb ,
@@ -298,13 +280,9 @@ def _load_linen_params_into_nnx(
298280 restore_args = ocp .checkpoint_utils .construct_restore_args (linen_abstract )
299281 restored = ckptr .restore (
300282 epath .Path (path ),
301- args = ocp .args .PyTreeRestore (
302- item = linen_abstract , restore_args = restore_args , partial_restore = True
303- ),
304- )
305- return _rebuild_nnx_with_values (
306- nnx_params_abstract , restored ["params" ]["params" ]
283+ args = ocp .args .PyTreeRestore (item = linen_abstract , restore_args = restore_args , partial_restore = True ),
307284 )
285+ return _rebuild_nnx_with_values (nnx_params_abstract , restored ["params" ]["params" ])
308286
309287
310288def _load_full_state_from_path (
@@ -388,7 +366,7 @@ def create_orbax_checkpoint_manager(
388366 enable_checkpointing : bool ,
389367 use_async : bool ,
390368 save_interval_steps : int ,
391- dataset_type : None | str = "tfds" ,
369+ dataset_type : None | str = None ,
392370 orbax_logger : Any = None , # pytype: disable=attribute-error
393371 use_ocdbt : bool = True ,
394372 use_zarr3 : bool = True ,
@@ -421,7 +399,7 @@ def create_orbax_checkpoint_manager(
421399 )
422400 }
423401
424- if dataset_type == "grain" :
402+ if dataset_type is not None and dataset_type == "grain" :
425403 item_names += ("iter" ,)
426404 item_handlers ["iter" ] = GrainCheckpointHandler ()
427405
@@ -798,9 +776,7 @@ def map_to_pspec(data):
798776 checkpoint_manager ,
799777 (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
800778 ):
801- checkpoint_path = str (
802- checkpoint_manager .directory / str (step ) / "items"
803- )
779+ checkpoint_path = str (checkpoint_manager .directory / str (step ) / "items" )
804780 restored_nnx = _load_linen_checkpoint_into_nnx (
805781 checkpoint_path ,
806782 abstract_unboxed_pre_state ,
@@ -831,9 +807,7 @@ def map_to_pspec(data):
831807 (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
832808 ):
833809 return (
834- checkpoint_manager .restore (
835- step , args = Composite (state = checkpoint_args )
836- ).state ,
810+ checkpoint_manager .restore (step , args = Composite (state = checkpoint_args )).state ,
837811 None ,
838812 )
839813 # Case 2: Matches if dataset type is "grain" and the data iterator is not a
0 commit comments