Skip to content

Commit 5ca18e9

Browse files
committed
[NNX] NNX migration (12/N): delete Linen code paths, classes, and compatibility flags
NNX is now the only model path (PR11 flipped pure_nnx/enable_nnx/pure_nnx_decoder to True), so these flags are no longer dispatch points. Delete the Linen code: - Collapse all flag and isinstance(model, nn.Module) dispatch to the NNX branch across ~22 src files (train.py, maxtext_utils, train_utils, sharding, diloco, maxengine, layerwise_quantization, grpo_trainer, lora_utils, checkpointing, convert_gpt3_ckpt_from_paxml, ...). Zero executable flag reads remain in src. - Delete TransformerLinenPure; the Linen decoder stack Decoder / DecoderLayer / SequentialBlockDecoderLayers (decoders.py 1525->47, only deepstack_process kept); and 28 dead *_as_linen ToLinen wrappers. The wrapped NNX classes are unchanged. - Remove the pure_nnx / enable_nnx / pure_nnx_decoder flags from configs/types.py, base.yml, inference/vllm.yml, and pyconfig. - Delete 21 obsolete Linen-only tests; drop redundant flag args elsewhere. Kept for focused follow-ups: the transformer_as_linen / init_initial_state NNX->Linen bridge (checkpoint-conversion tools), the Linen GRPO reference grpo_loss_fn (torch-gated correctness tests), and the Linen pipeline.py (NNX pipeline parallelism pending PR11.5; test_pipeline_subset skipped). Verified on CPU: NNX unit suite 213 passed / 28 skipped (3 non-regression fails), 5 train_compile AOT cases, nnx_decoders_test 40 passed, lint 10/10, base.yml config-load smoke. Net -5,346 lines across 59 files.
1 parent 5923cb2 commit 5ca18e9

59 files changed

Lines changed: 713 additions & 6014 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 27 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
"""
3636

3737
import argparse
38-
import functools
3938
import gc
4039
import os
4140
import sys
@@ -47,11 +46,7 @@
4746
from maxtext.configs import pyconfig
4847
from maxtext.utils.globals import MAXTEXT_PKG_DIR
4948
from maxtext.common import checkpointing
50-
from maxtext.common.common_types import MODEL_MODE_TRAIN
51-
from maxtext.layers import quantizations
5249
from maxtext.common import train_state_nnx
53-
from maxtext.models.models import transformer_as_linen
54-
from maxtext.optimizers import optimizers
5550
from maxtext.utils import max_logging
5651
from maxtext.utils import max_utils
5752
from maxtext.utils import maxtext_utils
@@ -92,23 +87,15 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
9287
devices_array = maxtext_utils.create_device_mesh(cfg)
9388
mesh = Mesh(devices_array, cfg.mesh_axes)
9489

95-
if cfg.pure_nnx:
96-
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
97-
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
98-
_, tx = train_utils.create_training_optimizer(cfg, model)
99-
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
90+
rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng)
91+
model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs)
92+
_, tx = train_utils.create_training_optimizer(cfg, model)
93+
_create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh)
10094

101-
def init_state_fn():
102-
nnx_model = _create_model_partial()
103-
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
104-
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
105-
106-
else:
107-
quant = quantizations.configure_quantization(cfg)
108-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
109-
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
110-
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
111-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
95+
def init_state_fn():
96+
nnx_model = _create_model_partial()
97+
optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param)
98+
return train_state_nnx.TrainStateNNX(nnx_model, optimizer)
11299

113100
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
114101
cfg.checkpoint_dir,
@@ -201,21 +188,15 @@ def init_state_fn():
201188
"['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None),
202189
}
203190

204-
if cfg.pure_nnx:
205-
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
206-
# model params -> ['model']<rest>.value
207-
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
208-
# step -> ['optimizer']['step'].value
209-
# opt count -> ['optimizer']['opt_state']['count'].value
210-
state_map = {
211-
".optimizer.step.value": ("step", None),
212-
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
213-
}
214-
else:
215-
state_map = {
216-
".step": ("step", None),
217-
".opt_state.count": ("opt_states_0.no_prefix_0.count", None),
218-
}
191+
# NNX state-tree paths after `nnx.split(TrainStateNNX)`:
192+
# model params -> ['model']<rest>.value
193+
# adam mu / nu -> ['optimizer']['opt_state']['mu' | 'nu']<rest>.value
194+
# step -> ['optimizer']['step'].value
195+
# opt count -> ['optimizer']['opt_state']['count'].value
196+
state_map = {
197+
".optimizer.step.value": ("step", None),
198+
".optimizer.opt_state.count.value": ("opt_states_0.no_prefix_0.count", None),
199+
}
219200

220201
def get_layer_prefix(keystr_pax):
221202
# different path format between decoder_layer variable
@@ -228,26 +209,15 @@ def get_layer_prefix(keystr_pax):
228209

229210
for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items():
230211
prefix_pax_opt_state = get_layer_prefix(keystr_pax)
231-
if cfg.pure_nnx:
232-
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
233-
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
234-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
235-
transform_fn,
236-
)
237-
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
238-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
239-
transform_fn,
240-
)
241-
else:
242-
state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn)
243-
state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (
244-
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
245-
transform_fn,
246-
)
247-
state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (
248-
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
249-
transform_fn,
250-
)
212+
state_map[f".model{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn)
213+
state_map[f".optimizer.opt_state.mu{keystr_maxtext}.value"] = (
214+
f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}",
215+
transform_fn,
216+
)
217+
state_map[f".optimizer.opt_state.nu{keystr_maxtext}.value"] = (
218+
f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}",
219+
transform_fn,
220+
)
251221

252222
def verify_fn(key_path, _):
253223
keystr = jax.tree_util.keystr(key_path)
@@ -299,7 +269,7 @@ def map_fn(key_path, value):
299269
max_logging.log("converted state finished")
300270
max_utils.print_mem_stats("converted state finished")
301271

302-
step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step
272+
step_value = int(converted_state.optimizer.step.value)
303273
if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state):
304274
max_logging.log(f"saved a checkpoint at step {step_value}")
305275
# Upon preemption, exit when and only when all ongoing saves are complete.

src/maxtext/common/checkpointing.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -949,19 +949,14 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
949949
if step is not None:
950950
actual_step = int(step)
951951
else:
952-
if config.pure_nnx:
953-
actual_step = int(state.optimizer.step) - 1
954-
else:
955-
# Linen TrainState has .step attribute
956-
actual_step = int(state.step) - 1
952+
actual_step = int(state.optimizer.step) - 1
957953

958954
if checkpoint_manager.latest_step() == actual_step:
959955
max_logging.log(f"Checkpoint for step {actual_step} already exists, skipping save.")
960956
return
961957

962-
if config.pure_nnx:
963-
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
964-
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())
958+
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
959+
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())
965960

966961
# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
967962
# This occurs if this function was called:

src/maxtext/configs/base.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,11 +1168,6 @@ position_id_per_seconds: 25
11681168
# Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
11691169
subslice_shape: ""
11701170

1171-
# NNX
1172-
enable_nnx: true
1173-
pure_nnx_decoder: true
1174-
pure_nnx: true
1175-
11761171
################################## Qwen3-Next Specific Configs ##################################
11771172
# Kernel size for the 1D convolution in the Gated Delta Net
11781173
gdn_conv_kernel_dim: 4

src/maxtext/configs/inference/vllm.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ base_config: "base.yml"
1616
attention: "vllm_rpa"
1717
model_call_mode: "inference"
1818

19-
# NNX required for vLLM integration
20-
enable_nnx: true
2119
# Avoid re-initializing JAX distributed system when using vLLM
2220
skip_jax_distributed_system: true
2321
# Scanned layers are not supported with vLLM integration

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
193193
)
194194

195195

196-
def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
197-
del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
196+
def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int):
198197
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
199198
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
200199

@@ -238,9 +237,7 @@ def validate_keys(keys):
238237
validate_model_call_mode(keys["model_call_mode"])
239238
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
240239
validate_rope_type(keys["rope_type"])
241-
validate_vocab_tiling(
242-
keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
243-
)
240+
validate_vocab_tiling(keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"])
244241
if keys["enable_rampup_batch_size"]:
245242
validate_rampup_batch_size(
246243
keys["per_device_batch_size_start"],

src/maxtext/configs/types.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -895,11 +895,8 @@ class HardwareAndMesh(BaseModel):
895895
CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity."
896896
)
897897
allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
898-
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
899898
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
900899
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
901-
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
902-
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
903900
remove_size_one_mesh_axis_from_type: bool = Field(
904901
True, description="Whether to remove size one mesh axis from type through jax.config."
905902
)
@@ -2498,8 +2495,6 @@ def validate_and_set_hlo_dump_defaults():
24982495
if self.distill_beta > 0.0:
24992496
if not self.scan_layers:
25002497
raise ValueError("a value of self.distill_beta > 0.0 requires self.scan_layers = True")
2501-
if not self.enable_nnx:
2502-
raise ValueError("a value of self.distill_beta > 0.0 requires self.enable_nnx = True")
25032498

25042499
# Validate distillation schedule parameters
25052500
if self.distill_alpha_end is not None and not 0.0 <= self.distill_alpha_end <= 1.0:

0 commit comments

Comments
 (0)