Skip to content

Commit 95026c0

Browse files
committed
[NNX] Delete Linen (1/4): collapse pure_nnx/enable_nnx/isinstance dispatch to NNX-only
Across the core training/utils/inference/RL/checkpoint-conversion code, statically collapse every pure_nnx / enable_nnx / isinstance(model, nn.Module) branch to the NNX path (the model is always NNX now). No flag reads remain in these files.
1 parent 1830487 commit 95026c0

19 files changed

Lines changed: 474 additions & 1896 deletions

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 30 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
"""
3636

3737
import argparse
38-
import functools
3938
import gc
4039
import os
4140
import sys
@@ -47,11 +46,7 @@
4746
from maxtext.configs import pyconfig
4847
from maxtext.utils.globals import MAXTEXT_PKG_DIR
4948
from maxtext.common import checkpointing
50-
from maxtext.common.common_types import MODEL_MODE_TRAIN
51-
from maxtext.layers import quantizations
5249
from maxtext.common import train_state_nnx
53-
from maxtext.models.models import transformer_as_linen
54-
from maxtext.optimizers import optimizers
5550
from maxtext.utils import max_logging
5651
from maxtext.utils import max_utils
5752
from maxtext.utils import maxtext_utils
@@ -92,23 +87,15 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
9287
devices_array = maxtext_utils.create_device_mesh(cfg)
9388
mesh = Mesh(devices_array, cfg.mesh_axes)
9489

95-
if cfg.pure_nnx:
96-
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
97-
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
98-
_, tx = train_utils.create_training_optimizer(cfg, model)
99-
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
90+
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
91+
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
92+
_, tx = train_utils.create_training_optimizer(cfg, model)
93+
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
10094

101-
def init_state_fn():
102-
nnx_model = _create_model_partial()
103-
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
104-
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
105-
106-
else:
107-
quant = quantizations.configure_quantization(cfg)
108-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
109-
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
110-
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
111-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
95+
def init_state_fn():
96+
nnx_model = _create_model_partial()
97+
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
98+
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
11299

113100
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
114101
cfg.checkpoint_dir,
@@ -201,24 +188,18 @@ def init_state_fn():
201188
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
202189
}
203190

204-
if cfg.pure_nnx:
205-
# NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
206-
# nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
207-
# `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
208-
# `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
209-
# model params -> ['model']<rest>.value
210-
# adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
211-
# step -> ['optimizer']['step'].value
212-
# opt count -> ['optimizer']['opt_state'][0]['count'].value
213-
state_map = {
214-
"['optimizer']['step'].value": ("step", None),
215-
"['optimizer']['opt_state'][0]['count'].value": ("opt_states_0.no_prefix_0.count", None),
216-
}
217-
else:
218-
state_map = {
219-
".step": ("step", None),
220-
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
221-
}
191+
# NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a
192+
# nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so
193+
# `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus
194+
# `.value` for the Variable leaf, plus `[idx]` for the optax tuple:
195+
# model params -> ['model']<rest>.value
196+
# adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu']<rest>.value
197+
# step -> ['optimizer']['step'].value
198+
# opt count -> ['optimizer']['opt_state'][0]['count'].value
199+
state_map = {
200+
"['optimizer']['step'].value": ("step", None),
201+
"['optimizer']['opt_state'][0]['count'].value": ("opt_states_0.no_prefix_0.count", None),
202+
}
222203

223204
def get_layer_prefix(keystr_pax):
224205
# different path format between decoder_layer variable
@@ -231,26 +212,15 @@ def get_layer_prefix(keystr_pax):
231212

232213
for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
233214
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
234-
if cfg.pure_nnx:
235-
state_map[f"['model']{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
236-
state_map[f"['optimizer']['opt_state'][0]['mu']{keystr_maxtext}.value"] = (
237-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
238-
transform_fn,
239-
)
240-
state_map[f"['optimizer']['opt_state'][0]['nu']{keystr_maxtext}.value"] = (
241-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
242-
transform_fn,
243-
)
244-
else:
245-
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
246-
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
247-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
248-
transform_fn,
249-
)
250-
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
251-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
252-
transform_fn,
253-
)
215+
state_map[f"['model']{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
216+
state_map[f"['optimizer']['opt_state'][0]['mu']{keystr_maxtext}.value"] = (
217+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
218+
transform_fn,
219+
)
220+
state_map[f"['optimizer']['opt_state'][0]['nu']{keystr_maxtext}.value"] = (
221+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
222+
transform_fn,
223+
)
254224

255225
def verify_fn(key_path, _):
256226
keystr = jax.tree_util.keystr(key_path)
@@ -302,7 +272,7 @@ def map_fn(key_path, value):
302272
max_logging.log("converted state finished")
303273
max_utils.print_mem_stats("converted state finished")
304274

305-
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
275+
step_value = int(converted_state.optimizer.step.value)
306276
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
307277
max_logging.log(f"saved a checkpoint at step {step_value}")
308278
# Upon preemption, exit when and only when all ongoing saves are complete.

src/maxtext/common/checkpointing.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -949,19 +949,14 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
949949
if step is not None:
950950
actual_step = int(step)
951951
else:
952-
if config.pure_nnx:
953-
actual_step = int(state.optimizer.step) - 1
954-
else:
955-
# Linen TrainState has .step attribute
956-
actual_step = int(state.step) - 1
952+
actual_step = int(state.optimizer.step) - 1
957953

958954
if checkpoint_manager.latest_step() == actual_step:
959955
max_logging.log(f"Checkpoint for step {actual_step} already exists, skipping save.")
960956
return
961957

962-
if config.pure_nnx:
963-
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
964-
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())
958+
# Save in the Linen on-disk layout so NNX and Linen checkpoints are interchangeable.
959+
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())
965960

966961
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
967962
# This occurs if this function was called:

0 commit comments

Comments
 (0)