Skip to content

Commit 5725403

Browse files
Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX
1 parent 56334f6 commit 5725403

8 files changed

Lines changed: 421 additions & 80 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,22 @@ def _build_single_axis_stacked_tensor(
384384
The final, assembled NumPy array for the MaxText parameter.
385385
"""
386386
tensors_to_stack = []
387+
# Heuristic to determine if we are stacking layers or experts.
388+
# If the number of items to stack equals the number of layers, it's a standard
389+
# scanned layer, and we use the configured param_scan_axis. Otherwise, it's
390+
# an unscanned MoE layer, and we stack along the expert axis (0).
391+
"""
392+
axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0
393+
"""
387394

388-
if config.scan_layers:
389-
# If it's a standard scanned layer, we use the configured param_scan_axis.
390-
axis_to_stack = config.param_scan_axis
395+
# Workaround to load the HF model due to mismatched tensor ordering
396+
if len(hf_source_keys) == config.base_num_decoder_layers:
397+
if getattr(config, "enable_nnx", False):
398+
axis_to_stack = 0
399+
else:
400+
axis_to_stack = config.param_scan_axis
391401
else:
392-
# Otherwise, it's an unscanned MoE layer, and we stack along the expert axis (0).
393402
axis_to_stack = 0
394-
395403
# The hook function needs the shape of an individual slice, not the full stacked tensor.
396404
# We calculate it by removing the stacking dimension from the final target shape.
397405
mt_slice_shape_list = list(target_shape)

src/maxtext/configs/base.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ autoregressive_decode_assert: ""
706706

707707
# For nsys profiler, pass the training command to nsys command
708708
# e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command}
709-
profiler: "" # Supported profiler: '', xplane, nsys
709+
profiler: "xplane" # Supported profiler: '', xplane, nsys
710710
# If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host.
711711
upload_all_profiler_results: False
712712
# Skip first n steps for profiling, to omit things like compilation and to give
@@ -1060,8 +1060,8 @@ position_id_per_seconds: 25
10601060
subslice_shape: ""
10611061

10621062
# NNX
1063-
enable_nnx: false
1064-
pure_nnx_decoder: false
1063+
enable_nnx: True
1064+
pure_nnx_decoder: True
10651065

10661066
################################## Qwen3-Next Specific Configs ##################################
10671067
# Kernel size for the 1D convolution in the Gated Delta Net

src/maxtext/layers/multi_token_prediction.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ def __init__(
136136
model_mode=MODEL_MODE_TRAIN,
137137
)
138138

139-
140139
@property
141140
def embedding_norm(self):
142141
return getattr(self, f"mtp_{self.layer_number}_embedding_norm")
@@ -169,7 +168,6 @@ def transformer_layer(self):
169168
def transformer_layer(self, module):
170169
setattr(self, f"mtp_{self.layer_number}_transformer_layer", module)
171170

172-
173171
def __call__(
174172
self,
175173
prev_hidden_state: jnp.ndarray,

0 commit comments

Comments (0)