@@ -182,13 +182,7 @@ def _default_for_sds(sds):
182182 def _make ():
183183 if "key" in str (sds .dtype ):
184184 base = jax .random .key (0 )
185- return (
186- base
187- if sds .shape == ()
188- else jax .random .split (base , int (np .prod (sds .shape ))).reshape (
189- sds .shape
190- )
191- )
185+ return base if sds .shape == () else jax .random .split (base , int (np .prod (sds .shape ))).reshape (sds .shape )
192186 return jnp .zeros (sds .shape , dtype = sds .dtype )
193187
194188 sharding = getattr (sds , "sharding" , None )
@@ -208,9 +202,7 @@ def _populate_pure_dict_from_partial(abstract_pure, partial_concrete):
208202 return {
209203 k : _populate_pure_dict_from_partial (
210204 v ,
211- partial_concrete .get (k )
212- if isinstance (partial_concrete , dict )
213- else None ,
205+ partial_concrete .get (k ) if isinstance (partial_concrete , dict ) else None ,
214206 )
215207 for k , v in abstract_pure .items ()
216208 }
@@ -243,29 +235,21 @@ def _load_linen_checkpoint_into_nnx(
243235 )
244236 )
245237 restore_args = ocp .checkpoint_utils .construct_restore_args (linen_abstract )
246- restored = ocp .args .PyTreeRestore (
247- item = linen_abstract , restore_args = restore_args , partial_restore = True
248- )
238+ restored = ocp .args .PyTreeRestore (item = linen_abstract , restore_args = restore_args , partial_restore = True )
249239 restored = ckptr .restore (epath .Path (path ), args = restored )
250240 partial_nnx = train_state_nnx .from_linen_checkpoint_dict (restored )
251241 return _populate_pure_dict_from_partial (nnx_abstract_pure , partial_nnx )
252242
253243
254244def _rebuild_nnx_with_values (abstract_nnx_state , concrete_weights ):
255245 """Fills each Variable in `abstract_nnx_state` with the matching restored array."""
256- leaves , treedef = jax .tree_util .tree_flatten (
257- abstract_nnx_state , is_leaf = lambda x : isinstance (x , nnx .Variable )
258- )
246+ leaves , treedef = jax .tree_util .tree_flatten (abstract_nnx_state , is_leaf = lambda x : isinstance (x , nnx .Variable ))
259247 concrete = jax .tree_util .tree_leaves (concrete_weights )
260248 if len (leaves ) != len (concrete ):
261249 raise ValueError (
262- f"Params load leaf-count mismatch: { len (leaves )} abstract Variables vs"
263- f" { len (concrete )} restored."
250+ f"Params load leaf-count mismatch: { len (leaves )} abstract Variables vs" f" { len (concrete )} restored."
264251 )
265- new_leaves = [
266- v .replace (value = a ) if isinstance (v , nnx .Variable ) else a
267- for v , a in zip (leaves , concrete )
268- ]
252+ new_leaves = [v .replace (value = a ) if isinstance (v , nnx .Variable ) else a for v , a in zip (leaves , concrete )]
269253 return jax .tree_util .tree_unflatten (treedef , new_leaves )
270254
271255
@@ -284,9 +268,7 @@ def _load_linen_params_into_nnx(
284268 NNX params Variables.
285269 """
286270 max_logging .log (f"Restoring Linen-layout params into NNX state at { path } " )
287- linen_abstract = train_state_nnx .to_linen_checkpoint_dict (
288- {"model" : nnx_params_abstract .to_pure_dict ()}
289- )
271+ linen_abstract = train_state_nnx .to_linen_checkpoint_dict ({"model" : nnx_params_abstract .to_pure_dict ()})
290272 ckptr = ocp .Checkpointer (
291273 ocp .PyTreeCheckpointHandler (
292274 restore_concurrent_gb = checkpoint_storage_concurrent_gb ,
@@ -298,13 +280,9 @@ def _load_linen_params_into_nnx(
298280 restore_args = ocp .checkpoint_utils .construct_restore_args (linen_abstract )
299281 restored = ckptr .restore (
300282 epath .Path (path ),
301- args = ocp .args .PyTreeRestore (
302- item = linen_abstract , restore_args = restore_args , partial_restore = True
303- ),
304- )
305- return _rebuild_nnx_with_values (
306- nnx_params_abstract , restored ["params" ]["params" ]
283+ args = ocp .args .PyTreeRestore (item = linen_abstract , restore_args = restore_args , partial_restore = True ),
307284 )
285+ return _rebuild_nnx_with_values (nnx_params_abstract , restored ["params" ]["params" ])
308286
309287
310288def _load_full_state_from_path (
@@ -804,9 +782,7 @@ def map_to_pspec(data):
804782 checkpoint_manager ,
805783 (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
806784 ):
807- checkpoint_path = str (
808- checkpoint_manager .directory / str (step ) / "items"
809- )
785+ checkpoint_path = str (checkpoint_manager .directory / str (step ) / "items" )
810786 restored_nnx = _load_linen_checkpoint_into_nnx (
811787 checkpoint_path ,
812788 abstract_unboxed_pre_state ,
@@ -837,9 +813,7 @@ def map_to_pspec(data):
837813 (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
838814 ):
839815 return (
840- checkpoint_manager .restore (
841- step , args = Composite (state = checkpoint_args )
842- ).state ,
816+ checkpoint_manager .restore (step , args = Composite (state = checkpoint_args )).state ,
843817 None ,
844818 )
845819 # Case 2: Matches if dataset type is "grain" and the data iterator is not a
0 commit comments