Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,22 @@
torch_dtype="bfloat16",
)

qwen3_1_7b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=2048,
intermediate_size=6144,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=40960,
rms_norm_eps=1.0e-6,
rope_theta=1000000.0,
tie_word_embeddings=True,
torch_dtype="bfloat16",
)

qwen3_4b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=2560,
Expand Down Expand Up @@ -816,16 +832,22 @@
"gemma3-12b": gemma3_12b_config,
"gemma3-27b": gemma3_27b_config,
"qwen3-0.6b": qwen3_0_6b_config,
"qwen3-1.7b": qwen3_1_7b_config,
"qwen3-1.7b-base": qwen3_1_7b_config,
"qwen3-4b": qwen3_4b_config,
"qwen3-4b-base": qwen3_4b_config,
"qwen3-4b-thinking-2507": qwen3_4b_config,
"qwen3-8b": qwen3_8b_config,
"qwen3-8b-base": qwen3_8b_config,
"qwen3-14b": qwen3_14b_config,
"qwen3-14b-base": qwen3_14b_config,
"qwen3-32b": qwen3_32b_config,
"llama3.1-8b": llama31_8b_config,
"llama3.1-8b-Instruct": llama31_8b_config,
"llama3.1-70b": llama31_70b_config,
"llama3.1-405b": llama31_405b_config,
"qwen3-30b-a3b": qwen3_30b_a3b_thinking_2507_config,
"qwen3-30b-a3b-base": qwen3_30b_a3b_thinking_2507_config,
"qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
"qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
"deepseek3-671b": deepseek3_671b_config,
Expand Down
12 changes: 12 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,15 +2333,21 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
"gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
"gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
Expand All @@ -2365,15 +2371,21 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
"gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-1.7b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-1.7b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-8b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-14b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-30b-a3b-base": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
Expand Down
37 changes: 37 additions & 0 deletions src/maxtext/configs/models/qwen3-1.7b-base.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-1.7b-base

base_emb_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 8
base_mlp_dim: 6144
base_num_decoder_layers: 28
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True

logits_via_embedding: True # from "tie_word_embeddings": true
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"
37 changes: 37 additions & 0 deletions src/maxtext/configs/models/qwen3-1.7b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-1.7b

base_emb_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 8
base_mlp_dim: 6144
base_num_decoder_layers: 28
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True

logits_via_embedding: True # from "tie_word_embeddings": true
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"
37 changes: 37 additions & 0 deletions src/maxtext/configs/models/qwen3-14b-base.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-14b-base

base_emb_dim: 5120
base_num_query_heads: 40
base_num_kv_heads: 8
base_mlp_dim: 17408
base_num_decoder_layers: 40
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True

logits_via_embedding: False # different from 0.6 and 4B variants, "tie_word_embeddings": false
normalize_embedding_logits: False

tokenizer_type: "huggingface"

40 changes: 40 additions & 0 deletions src/maxtext/configs/models/qwen3-30b-a3b-base.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for Qwen3-30B-A3B-base

# Core Architectural Parameters
decoder_block: "qwen3_moe"
base_emb_dim: 2048
base_mlp_dim: 768
base_num_query_heads: 32
base_num_kv_heads: 4
base_num_decoder_layers: 48
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 151936
normalization_layer_epsilon: 1.0e-6
use_qk_norm: True

# MoE Specific Parameters
num_experts: 128
num_experts_per_tok: 8
base_moe_mlp_dim: 768
norm_topk_prob: true

# RoPE Settings
rope_max_timescale: 10_000_000

# General Model Settings
enable_dropout: False
37 changes: 37 additions & 0 deletions src/maxtext/configs/models/qwen3-4b-base.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-4b-base

base_emb_dim: 2560
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 9728
base_num_decoder_layers: 36
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True

logits_via_embedding: True # from "tie_word_embeddings": true
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"
38 changes: 38 additions & 0 deletions src/maxtext/configs/models/qwen3-8b-base.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-8b-base

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 12288
base_num_decoder_layers: 36
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True

logits_via_embedding: False # different from smaller variants, "tie_word_embeddings": false
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"

12 changes: 12 additions & 0 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ rl:
grpo_epsilon: 0.2
loss_algo: 'grpo' # grpo or gspo-token

# ====== LoRA ======
# Low-Rank Adaptation for the actor model. When enabled, only the LoRA parameters
# are trained and checkpointed, significantly reducing memory and compute.
lora:
enabled: False # this is still WIP
rank: 32
alpha: 64.0
# Regex matching module paths to apply LoRA to.
# Qwix uses re.fullmatch over slash-separated module paths like:
# layers/self_attention/query, layers/self_attention/key, ...
module_path: 'layers/(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))'

# ====== Models ======
# for MaxText
Expand Down Expand Up @@ -163,6 +174,7 @@ max_num_checkpoints_to_keep: 10

# ====== Reward ======

reward_exact_answer: 5.0
reward_exact_format_match: 3.0
reward_white_space_format_match: 1.5
reward_partial_format_match: 0.5
Expand Down
Loading
Loading