Skip to content

Commit 602c49d

Browse files
committed
pyink
1 parent b45f2b6 commit 602c49d

3 files changed

Lines changed: 10 additions & 9 deletions

File tree

src/maxtext/checkpoint_conversion/utils/param_mapping.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ def QWEN3_5_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals
10421042
): f"model.language_model.layers.{i}.mlp.experts.gate_up_proj",
10431043
}
10441044
)
1045-
1045+
10461046
# Vision mapping for Qwen3.5
10471047
if maxtext_config.use_multimodal and "vision_config" in config:
10481048
vision_config = config["vision_config"]
@@ -1094,15 +1094,11 @@ def QWEN3_5_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals
10941094
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel"] = (
10951095
"model.visual.merger.linear_fc1.weight"
10961096
)
1097-
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-bias"] = (
1098-
"model.visual.merger.linear_fc1.bias"
1099-
)
1097+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-bias"] = "model.visual.merger.linear_fc1.bias"
11001098
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel"] = (
11011099
"model.visual.merger.linear_fc2.weight"
11021100
)
1103-
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-bias"] = (
1104-
"model.visual.merger.linear_fc2.bias"
1105-
)
1101+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-bias"] = "model.visual.merger.linear_fc2.bias"
11061102

11071103
return mapping
11081104

src/maxtext/multimodal/processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens, is_video: bool =
224224
bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN
225225
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
226226
from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel
227+
227228
tokens = QwenTokens(config)
228229

229230
if is_video:
@@ -238,8 +239,9 @@ def get_bidirectional_mask_audio(config, decoder_input_tokens):
238239
bidirectional_mask_audio = None
239240
if config.model_name in ["qwen3-omni-30b-a3b"]:
240241
from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel
242+
241243
tokens = QwenTokens(config)
242244

243245
# Create bidirectional_mask for audio token merging
244-
bidirectional_mask_audio = (decoder_input_tokens == tokens.audio_pad)
246+
bidirectional_mask_audio = decoder_input_tokens == tokens.audio_pad
245247
return bidirectional_mask_audio

src/maxtext/multimodal/processor_qwen3_omni.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ def __init__(self, config=None):
9696
# If config is None, will fall back to default Qwen3-Omni tokens.
9797
self.model_name = getattr(config, "model_name", None) or self._DEFAULT_MODEL
9898
# Match by prefix (e.g. "qwen3.5" covers qwen3.5 family), fall back to default.
99-
token_config = next((v for k, v in QWEN_SPECIAL_TOKEN_CONFIGS.items() if self.model_name.startswith(k)), QWEN_SPECIAL_TOKEN_CONFIGS[self._DEFAULT_MODEL])
99+
token_config = next(
100+
(v for k, v in QWEN_SPECIAL_TOKEN_CONFIGS.items() if self.model_name.startswith(k)),
101+
QWEN_SPECIAL_TOKEN_CONFIGS[self._DEFAULT_MODEL],
102+
)
100103
self.__dict__.update(token_config)
101104

102105

0 commit comments

Comments
 (0)