Skip to content

Commit 6c8c56f

Browse files
committed
[NNX] Delete Linen (4/4): remove the pure_nnx/enable_nnx/pure_nnx_decoder config flags
Remove the three flags (pure_nnx, enable_nnx, pure_nnx_decoder) from types.py, base.yml, inference/vllm.yml, pyconfig_deprecated.py, and the post-train distillation configs. NNX is now the only code path, so the flags no longer exist.
1 parent 83bf2f1 commit 6c8c56f

7 files changed

Lines changed: 2 additions & 20 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,11 +1171,6 @@ position_id_per_seconds: 25
11711171
# Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
11721172
subslice_shape: ""
11731173

1174-
# NNX
1175-
enable_nnx: true
1176-
pure_nnx_decoder: true
1177-
pure_nnx: true
1178-
11791174
################################## Qwen3-Next Specific Configs ##################################
11801175
# Kernel size for the 1D convolution in the Gated Delta Net
11811176
gdn_conv_kernel_dim: 4

src/maxtext/configs/inference/vllm.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ base_config: "base.yml"
1616
attention: "vllm_rpa"
1717
model_call_mode: "inference"
1818

19-
# NNX required for vLLM integration
20-
enable_nnx: true
2119
# Avoid re-initializing JAX distributed system when using vLLM
2220
skip_jax_distributed_system: true
2321
# Scanned layers are not supported with vLLM integration

src/maxtext/configs/post_train/distillation_gpt_oss_20b.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ distill_alpha: 0.5
1919
distill_temperature: 1.0
2020
distill_beta: 0
2121
distill_layer_indices: []
22-
enable_nnx: True
2322
load_balance_loss_weight: 0.001
2423

2524
ici_fsdp_parallelism: 32

src/maxtext/configs/post_train/distillation_qwen3_30b_base.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ distill_alpha: 0.6
2121
distill_temperature: 1.0
2222
distill_beta: 1.0
2323
distill_layer_indices: [0,1,2,3,4,5,6,7]
24-
enable_nnx: True
2524
load_balance_loss_weight: 0.001
2625

2726
ici_fsdp_parallelism: -1

src/maxtext/configs/post_train/distillation_qwen3_30b_base_pdbs8.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ distill_alpha: 0.6
2121
distill_temperature: 1.0
2222
distill_beta: 1.0
2323
distill_layer_indices: [0,1,2,3,4,5,6,7]
24-
enable_nnx: True
2524
load_balance_loss_weight: 0.001
2625

2726
ici_fsdp_parallelism: -1

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
193193
)
194194

195195

196-
def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
197-
del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
196+
def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int):
198197
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
199198
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
200199

@@ -238,9 +237,7 @@ def validate_keys(keys):
238237
validate_model_call_mode(keys["model_call_mode"])
239238
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
240239
validate_rope_type(keys["rope_type"])
241-
validate_vocab_tiling(
242-
keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
243-
)
240+
validate_vocab_tiling(keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"])
244241
if keys["enable_rampup_batch_size"]:
245242
validate_rampup_batch_size(
246243
keys["per_device_batch_size_start"],

src/maxtext/configs/types.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -895,11 +895,8 @@ class HardwareAndMesh(BaseModel):
895895
CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity."
896896
)
897897
allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
898-
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
899898
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
900899
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
901-
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
902-
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
903900
remove_size_one_mesh_axis_from_type: bool = Field(
904901
True, description="Whether to remove size one mesh axis from type through jax.config."
905902
)
@@ -2555,8 +2552,6 @@ def validate_and_set_hlo_dump_defaults():
25552552
if self.distill_beta > 0.0:
25562553
if not self.scan_layers:
25572554
raise ValueError("a value of self.distill_beta > 0.0 requires self.scan_layers = True")
2558-
if not self.enable_nnx:
2559-
raise ValueError("a value of self.distill_beta > 0.0 requires self.enable_nnx = True")
25602555

25612556
# Validate distillation schedule parameters
25622557
if self.distill_alpha_end is not None and not 0.0 <= self.distill_alpha_end <= 1.0:

0 commit comments

Comments
 (0)