Skip to content

Commit c40004a

Browse files
committed
fix: update train_state_nnx import path after #3929 relocation
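PR #3929 moved src/maxtext/layers/train_state_nnx.py to src/maxtext/common/train_state_nnx.py. Update remaining imports in diloco.py and three test files so PR11 still imports correctly. The changes to src/maxtext/common/checkpointing.py in this commit are formatting-only: wrapped multi-line expressions are joined onto single lines with no change in behavior.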
1 parent e6c5bee commit c40004a

5 files changed

Lines changed: 15 additions & 41 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,7 @@ def _default_for_sds(sds):
182182
def _make():
183183
if "key" in str(sds.dtype):
184184
base = jax.random.key(0)
185-
return (
186-
base
187-
if sds.shape == ()
188-
else jax.random.split(base, int(np.prod(sds.shape))).reshape(
189-
sds.shape
190-
)
191-
)
185+
return base if sds.shape == () else jax.random.split(base, int(np.prod(sds.shape))).reshape(sds.shape)
192186
return jnp.zeros(sds.shape, dtype=sds.dtype)
193187

194188
sharding = getattr(sds, "sharding", None)
@@ -208,9 +202,7 @@ def _populate_pure_dict_from_partial(abstract_pure, partial_concrete):
208202
return {
209203
k: _populate_pure_dict_from_partial(
210204
v,
211-
partial_concrete.get(k)
212-
if isinstance(partial_concrete, dict)
213-
else None,
205+
partial_concrete.get(k) if isinstance(partial_concrete, dict) else None,
214206
)
215207
for k, v in abstract_pure.items()
216208
}
@@ -243,29 +235,21 @@ def _load_linen_checkpoint_into_nnx(
243235
)
244236
)
245237
restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract)
246-
restored = ocp.args.PyTreeRestore(
247-
item=linen_abstract, restore_args=restore_args, partial_restore=True
248-
)
238+
restored = ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True)
249239
restored = ckptr.restore(epath.Path(path), args=restored)
250240
partial_nnx = train_state_nnx.from_linen_checkpoint_dict(restored)
251241
return _populate_pure_dict_from_partial(nnx_abstract_pure, partial_nnx)
252242

253243

254244
def _rebuild_nnx_with_values(abstract_nnx_state, concrete_weights):
255245
"""Fills each Variable in `abstract_nnx_state` with the matching restored array."""
256-
leaves, treedef = jax.tree_util.tree_flatten(
257-
abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable)
258-
)
246+
leaves, treedef = jax.tree_util.tree_flatten(abstract_nnx_state, is_leaf=lambda x: isinstance(x, nnx.Variable))
259247
concrete = jax.tree_util.tree_leaves(concrete_weights)
260248
if len(leaves) != len(concrete):
261249
raise ValueError(
262-
f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs"
263-
f" {len(concrete)} restored."
250+
f"Params load leaf-count mismatch: {len(leaves)} abstract Variables vs" f" {len(concrete)} restored."
264251
)
265-
new_leaves = [
266-
v.replace(value=a) if isinstance(v, nnx.Variable) else a
267-
for v, a in zip(leaves, concrete)
268-
]
252+
new_leaves = [v.replace(value=a) if isinstance(v, nnx.Variable) else a for v, a in zip(leaves, concrete)]
269253
return jax.tree_util.tree_unflatten(treedef, new_leaves)
270254

271255

@@ -284,9 +268,7 @@ def _load_linen_params_into_nnx(
284268
NNX params Variables.
285269
"""
286270
max_logging.log(f"Restoring Linen-layout params into NNX state at {path}")
287-
linen_abstract = train_state_nnx.to_linen_checkpoint_dict(
288-
{"model": nnx_params_abstract.to_pure_dict()}
289-
)
271+
linen_abstract = train_state_nnx.to_linen_checkpoint_dict({"model": nnx_params_abstract.to_pure_dict()})
290272
ckptr = ocp.Checkpointer(
291273
ocp.PyTreeCheckpointHandler(
292274
restore_concurrent_gb=checkpoint_storage_concurrent_gb,
@@ -298,13 +280,9 @@ def _load_linen_params_into_nnx(
298280
restore_args = ocp.checkpoint_utils.construct_restore_args(linen_abstract)
299281
restored = ckptr.restore(
300282
epath.Path(path),
301-
args=ocp.args.PyTreeRestore(
302-
item=linen_abstract, restore_args=restore_args, partial_restore=True
303-
),
304-
)
305-
return _rebuild_nnx_with_values(
306-
nnx_params_abstract, restored["params"]["params"]
283+
args=ocp.args.PyTreeRestore(item=linen_abstract, restore_args=restore_args, partial_restore=True),
307284
)
285+
return _rebuild_nnx_with_values(nnx_params_abstract, restored["params"]["params"])
308286

309287

310288
def _load_full_state_from_path(
@@ -804,9 +782,7 @@ def map_to_pspec(data):
804782
checkpoint_manager,
805783
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
806784
):
807-
checkpoint_path = str(
808-
checkpoint_manager.directory / str(step) / "items"
809-
)
785+
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
810786
restored_nnx = _load_linen_checkpoint_into_nnx(
811787
checkpoint_path,
812788
abstract_unboxed_pre_state,
@@ -837,9 +813,7 @@ def map_to_pspec(data):
837813
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
838814
):
839815
return (
840-
checkpoint_manager.restore(
841-
step, args=Composite(state=checkpoint_args)
842-
).state,
816+
checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
843817
None,
844818
)
845819
# Case 2: Matches if dataset type is "grain" and the data iterator is not a

src/maxtext/trainers/diloco/diloco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import optax
3535

3636
from maxtext.configs import pyconfig
37-
from maxtext.layers.train_state_nnx import TrainStateNNX
37+
from maxtext.common.train_state_nnx import TrainStateNNX
3838

3939
Batch = Any
4040
Params = PyTree

tests/integration/diloco_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import pytest
3131

3232
from maxtext.configs.pyconfig import initialize_pydantic
33-
from maxtext.layers.train_state_nnx import TrainStateNNX
33+
from maxtext.common.train_state_nnx import TrainStateNNX
3434
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
3535
from maxtext.trainers.diloco import diloco
3636
from tests.utils.test_helpers import get_test_config_path

tests/unit/maxtext_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from maxtext.utils import maxtext_utils
4343
from maxtext.utils import maxtext_utils_nnx
4444
from maxtext.utils import model_creation_utils
45-
from maxtext.layers import train_state_nnx
45+
from maxtext.common import train_state_nnx
4646
from maxtext.utils import sharding
4747
from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations
4848
from tests.utils.test_helpers import get_test_config_path

tests/unit/state_dtypes_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from maxtext.optimizers import optimizers
2929
from maxtext.utils import maxtext_utils
3030
from maxtext.utils import model_creation_utils
31-
from maxtext.layers import train_state_nnx
31+
from maxtext.common import train_state_nnx
3232
from tests.utils.test_helpers import get_test_config_path
3333

3434
Transformer = models.transformer_as_linen

0 commit comments

Comments
 (0)