Skip to content

Commit e070af5

Browse files
authored
[Cherry-Pick] PR 3583 into release/v1.0 (#3598)
1 parent 0faab2e commit e070af5

3 files changed

Lines changed: 11 additions & 13 deletions

File tree

paddleformers/cli/train/sft/workflow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ def run_sft(
235235
if "DeepseekV3" in str(model_config.architectures):
236236
training_args.prediction_loss_only = True
237237

238+
if "qwen3_vl" in model_config.model_type and not model_args.lora:
239+
if training_args.sequence_parallel:
240+
logger.warning("Qwen3VL model do not support `sequence_parallel` yet, temporarily set to False")
241+
training_args.sequence_parallel = False
242+
238243
LlmMetaConfig.set_llm_config(model_config, training_args)
239244
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm
240245

paddleformers/transformers/qwen3_vl/modeling_fleet.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _forward_impl(
228228
packed_seq_params=packed_seq_params,
229229
)
230230
hidden_states = self._forward_mlp(hidden_states)
231-
if self.layer_number in range(len(deepstack_visual_emb)):
231+
if deepstack_visual_emb and self.layer_number in range(len(deepstack_visual_emb)):
232232
# print("process _deepstack_process ",hidden_states.shape,visual_pos_masks.shape,deepstack_visual_emb[self.layer_number].shape)
233233
hidden_states = self._deepstack_process(
234234
hidden_states=hidden_states,
@@ -339,6 +339,7 @@ class Qwen3VLTextProvider(GPTModelProvider):
339339
use_flash_attention: bool = True
340340
use_fused_linear_cross_entropy: bool = True
341341
high_precision_rope: bool = True
342+
moe_grouped_gemm: bool = True
342343

343344
n_shared_experts: int = 0
344345
transform_rules = {
@@ -1125,15 +1126,6 @@ def forward(
11251126
else:
11261127
if position_ids.shape == input_ids.shape:
11271128
position_ids = position_ids.expand(3, position_ids.shape[0], -1)
1128-
else:
1129-
batch_size, seq_length = input_ids.shape
1130-
position_ids = paddle.arange(seq_length)
1131-
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1132-
if cache_position is not None:
1133-
delta = cache_position[0] + self.rope_deltas
1134-
else:
1135-
delta = paddle.zeros((batch_size, seq_length))
1136-
position_ids = position_ids + delta
11371129

11381130
input_dict = {
11391131
"input_ids": input_ids,

paddleformers/transformers/qwen3_vl_moe/modeling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from ..model_outputs import BaseModelOutputWithPast, ModelOutput
4747
from ..model_utils import PretrainedModel
4848
from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS
49-
from ..qwen3_vl.modeling_fleet import Qwen3VLModel, Qwen3VLProvider
49+
from ..qwen3_vl.modeling_fleet import Qwen3VLModelDist, Qwen3VLProvider
5050
from ..utils import logger
5151
from .configuration import (
5252
Qwen3VLMoeConfig,
@@ -376,7 +376,7 @@ def _gen_aoa_config(cls, config: Qwen3VLMoeConfig):
376376
else:
377377
split_experts_up_gate = ""
378378
split_experts_down = ""
379-
for expert_id in range(config.text_config.n_routed_experts):
379+
for expert_id in range(config.text_config.num_experts):
380380
split_experts_up_gate += f"{llm_prefix}{layer_id + 1}.mlp.experts.{expert_id}.up_gate_proj.weight,"
381381
split_experts_down += f"{llm_prefix}{layer_id + 1}.mlp.experts.{expert_id}.down_proj.weight,"
382382
split_experts_down += "axis=0"
@@ -2594,12 +2594,13 @@ def __new__(cls, config, have_criterion=True):
25942594
config.pipeline_model_parallel_size = max(config.pipeline_model_parallel_size, 1)
25952595
config.virtual_pipeline_model_parallel_size = max(config.virtual_pipeline_model_parallel_size, 1)
25962596
config.expert_model_parallel_size = max(config.expert_model_parallel_size, 1)
2597+
config.moe_grouped_gemm = True
25972598
criterion = None
25982599
if have_criterion:
25992600
criterion = CriterionLayer(config.text_config)
26002601
model_provider_class = Qwen3VLProvider
26012602
model_provider = model_provider_class.from_config(config)
2602-
qwen3vl_model = Qwen3VLModel(model_provider, model_version=config.model_type, criterion=criterion)
2603+
qwen3vl_model = Qwen3VLModelDist(model_provider, model_version=config.model_type, criterion=criterion)
26032604
qwen3vl_model._gen_aoa_config = cls._gen_aoa_config
26042605
qwen3vl_model._gen_inv_aoa_config = cls._gen_inv_aoa_config
26052606
qwen3vl_model._get_tensor_parallel_mappings = cls._get_tensor_parallel_mappings

0 commit comments

Comments
 (0)