Skip to content

Commit 02c6342

Browse files
xibinliuecnal-cienet
authored andcommitted
NNX migration preparation: pure_nnx flag and init_state_fn
- pure_nnx: a flag to to choose pure NNX logic when NNX and linen models co-exist. - init_state_fn: a function to initialize the model state for the training. It will be set to different function for NNX and Linen.
1 parent fa4a13d commit 02c6342

27 files changed

Lines changed: 560 additions & 124 deletions

src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"""
3636

3737
import argparse
38+
import functools
3839
import gc
3940
import os
4041
import sys
@@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
8788
mesh = Mesh(devices_array, cfg.mesh_axes)
8889

8990
quant = quantizations.configure_quantization(cfg)
90-
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
91+
if cfg.pure_nnx:
92+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
93+
else:
94+
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
9195
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
9296
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
9397

@@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
98102
cfg.checkpoint_period,
99103
)
100104

101-
state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
105+
if cfg.pure_nnx:
106+
# NNX has a different function to init the training state.
107+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
108+
else:
109+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
110+
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
102111
max_logging.log("start")
103112
max_utils.print_mem_stats("After params initialized")
104113

src/maxtext/common/gcloud_stub.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def is_decoupled() -> bool: # dynamic check so setting env after initial import
4343
return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"
4444

4545

46+
def is_pure_nnx() -> bool: # dynamic check so setting env after initial import still works
47+
"""Return True when running in pure NNX mode (PURE_NNX=TRUE env var).
48+
49+
Defaults to FALSE — Linen is the default test mode.
50+
Set PURE_NNX=TRUE to opt in to NNX mode (skips linen_only tests, runs nnx_only tests).
51+
"""
52+
return os.environ.get("PURE_NNX", "FALSE").upper() == "TRUE"
53+
54+
4655
T = TypeVar("T")
4756

4857

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ logical_axis_rules: [
514514
['paged_kv_head_dim_size', []],
515515
['dense_layers', []],
516516
['moe_layers', []],
517+
['num_activations', []],
517518
['engram_dim', ['tensor']],
518519
['mhc', []],
519520
['diloco', 'diloco'],
@@ -1088,6 +1089,7 @@ subslice_shape: ""
10881089
# NNX
10891090
enable_nnx: True
10901091
pure_nnx_decoder: True
1092+
pure_nnx: True
10911093

10921094
################################## Qwen3-Next Specific Configs ##################################
10931095
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/configs/decoupled_base_test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ eval_dataset_name: 'c4/en:3.1.0'
3030
# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
3131
attention: "dot_product"
3232

33+
# Default to Linen mode for tests; NNX is opt-in via PURE_NNX=TRUE.
34+
pure_nnx: False
35+
pure_nnx_decoder: False
36+
3337
# Avoid HLO dump overhead.
3438
dump_hlo: false
3539
jax_cache_dir: ""

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,7 @@ class HardwareAndMesh(BaseModel):
784784
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
785785
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
786786
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
787+
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
787788

788789

789790
class LayoutAndSharding(BaseModel):

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -546,23 +546,43 @@ def setup_train_loop(
546546
max_logging.log("Training mesh used for the workload")
547547
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
548548
training_devices = jax.devices()[num_inference_devices:]
549-
model = mt.from_config(config, devices=training_devices)
549+
if config.pure_nnx:
550+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
551+
else:
552+
model = mt.from_config(config, devices=training_devices)
550553
mesh = model.mesh
551554
max_logging.log("Inference mesh used for the workload")
552555
inference_devices = jax.devices()[:num_inference_devices]
553-
inference_model = mt.from_config(config_inference, devices=inference_devices)
556+
if config_inference.pure_nnx:
557+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
558+
else:
559+
inference_model = mt.from_config(config_inference, devices=inference_devices)
554560
inference_mesh = inference_model.mesh
555-
init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh)
561+
init_rng = jax.random.PRNGKey(config.init_weights_seed)
562+
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
563+
if config.pure_nnx:
564+
# NNX has a different function to init the training state.
565+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
566+
else:
567+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
568+
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)
556569

557570
with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
558571
data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh)
559572
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
560-
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
573+
data_iterator, config, mesh, checkpoint_manager, init_state_fn
561574
)
562575

563576
# create inference_state_mesh_shardings from inference_mesh
577+
if config_inference.pure_nnx:
578+
# NNX has a different function to init the training state.
579+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
580+
else:
581+
init_inference_state_fn = functools.partial(
582+
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
583+
)
564584
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
565-
inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False
585+
config_inference, inference_mesh, init_inference_state_fn, is_training=False
566586
)[2]
567587
if not config.using_pipeline_parallelism:
568588
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None):
113113

114114
# Model and Optimizer definition
115115
quant = quantizations.configure_quantization(config)
116-
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
116+
if config.pure_nnx:
117+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
118+
else:
119+
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
117120
self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))
118121

119122
self.abstract_params = None
@@ -229,17 +232,25 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
229232
rng1, rng2, rng3 = jax.random.split(rng, 3)
230233
if params:
231234
print("Resharding given params")
235+
if self.config.pure_nnx:
236+
# NNX has a different function to init the training state.
237+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
238+
else:
239+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
232240
_, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
233-
self.model, None, self.config, rng, self._mesh, False
241+
self.config, self._mesh, init_state_fn, False
234242
)
235243
# reshard given params based on shardings from config in MaxEngine
236244
params = jax.device_put(params, state_mesh_shardings.params)
237245
state = maxtext_utils.init_decode_state(None, params)
238246
state = max_utils.unbox_logicallypartioned(state)
239247
else:
240-
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(
241-
self.model, self.config, rng1, self._mesh, None
242-
)
248+
if self.config.pure_nnx:
249+
# NNX has a different function to init the training state.
250+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
251+
else:
252+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
253+
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
243254
# pylint: disable=isinstance-second-argument-not-valid-type
244255
self.abstract_params = jax.tree_util.tree_map(
245256
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def get_topology_mesh(config):
9191

9292
def get_shaped_inputs(topology_mesh, config):
9393
"""Get shaped abstractions of inputs to train_step: state, batch and rng"""
94+
if config.pure_nnx:
95+
raise NotImplementedError("pure_nnx AOT compilation support not yet implemented.")
9496
# Construct the model and optimizer to get shaped versions of the state
9597
quant = quantizations.configure_quantization(config)
9698
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
@@ -103,13 +105,13 @@ def get_shaped_inputs(topology_mesh, config):
103105
_, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)
104106
shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)
105107

108+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng)
109+
106110
# Shaped state
107-
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
108-
model, tx, config, example_rng, topology_mesh
109-
)
111+
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True)
110112

111113
# unsharded logical annotations
112-
logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh)
114+
logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)
113115

114116
# Shaped batch
115117
shaped_batch = maxtext_utils.get_shaped_batch(config)

src/maxtext/utils/generate_param_only_checkpoint.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16.
2323
"""
2424

25+
import functools
2526
import os.path
2627
from typing import Sequence
2728

@@ -42,8 +43,6 @@
4243
from maxtext.utils import max_utils
4344
from maxtext.utils import maxtext_utils
4445

45-
Transformer = models.transformer_as_linen
46-
4746

4847
def _possibly_unroll_params(config, training_state, training_state_annotations, mesh):
4948
"""Unroll scanned input layers when force_unroll is set."""
@@ -93,12 +92,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
9392
"""Read training checkpoint at path defined by load_full_state_path."""
9493
# Model and Optimizer definition
9594
quant = quantizations.configure_quantization(config)
96-
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
95+
if config.pure_nnx:
96+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
97+
else:
98+
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
9799
rng = random.PRNGKey(0)
98100
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
99101
tx = optimizers.get_optimizer(config, learning_rate_schedule)
102+
if config.pure_nnx:
103+
# NNX has a different function to init the training state.
104+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
105+
else:
106+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng)
100107
state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state(
101-
model, None, tx, config, rng, mesh, checkpoint_manager
108+
None, config, mesh, checkpoint_manager, init_state_fn
102109
)
103110
num_params = max_utils.calculate_num_params_from_pytree(state.params)
104111
max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion")
@@ -109,7 +116,10 @@ def _generate_lora_decode_checkpoints(config, mesh):
109116
"""Read lora checkpoints checkpoint at path defined by load_full_state_path."""
110117
# Model and Optimizer definition
111118
quant = quantizations.configure_quantization(config)
112-
model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN)
119+
if config.pure_nnx:
120+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
121+
else:
122+
model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN)
113123
rng = random.PRNGKey(0)
114124
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
115125
tx = optimizers.get_optimizer(config, learning_rate_schedule)

src/maxtext/utils/layerwise_quantization.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
3131
"""
3232

33+
import functools
3334
import os
3435
from typing import Any, Sequence
3536

@@ -174,12 +175,19 @@ def __init__(self, config: Any, rng: PRNGKeyType):
174175

175176
# Model and quantization config
176177
self.quant = quantizations.configure_quantization(config)
177-
model = models.transformer_as_linen(
178-
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
179-
)
180-
self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(
181-
model, None, self.config, self.rng, self._mesh, False
182-
)
178+
if self.config.pure_nnx:
179+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
180+
else:
181+
model = models.transformer_as_linen(
182+
config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN
183+
)
184+
if self.config.pure_nnx:
185+
# NNX has a different function to init the training state.
186+
raise NotImplementedError("Pure NNX support has not been implemented yet.")
187+
else:
188+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng)
189+
190+
self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False)
183191

184192
def load_and_quantize(self) -> None:
185193
"""

0 commit comments

Comments
 (0)