|
46 | 46 | from ..model_outputs import BaseModelOutputWithPast, ModelOutput |
47 | 47 | from ..model_utils import PretrainedModel |
48 | 48 | from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS |
49 | | -from ..qwen3_vl.modeling_fleet import Qwen3VLModel, Qwen3VLProvider |
| 49 | +from ..qwen3_vl.modeling_fleet import Qwen3VLModelDist, Qwen3VLProvider |
50 | 50 | from ..utils import logger |
51 | 51 | from .configuration import ( |
52 | 52 | Qwen3VLMoeConfig, |
@@ -376,7 +376,7 @@ def _gen_aoa_config(cls, config: Qwen3VLMoeConfig): |
376 | 376 | else: |
377 | 377 | split_experts_up_gate = "" |
378 | 378 | split_experts_down = "" |
379 | | - for expert_id in range(config.text_config.n_routed_experts): |
| 379 | + for expert_id in range(config.text_config.num_experts): |
380 | 380 | split_experts_up_gate += f"{llm_prefix}{layer_id + 1}.mlp.experts.{expert_id}.up_gate_proj.weight," |
381 | 381 | split_experts_down += f"{llm_prefix}{layer_id + 1}.mlp.experts.{expert_id}.down_proj.weight," |
382 | 382 | split_experts_down += "axis=0" |
@@ -2594,12 +2594,13 @@ def __new__(cls, config, have_criterion=True): |
2594 | 2594 | config.pipeline_model_parallel_size = max(config.pipeline_model_parallel_size, 1) |
2595 | 2595 | config.virtual_pipeline_model_parallel_size = max(config.virtual_pipeline_model_parallel_size, 1) |
2596 | 2596 | config.expert_model_parallel_size = max(config.expert_model_parallel_size, 1) |
| 2597 | + config.moe_grouped_gemm = True |
2597 | 2598 | criterion = None |
2598 | 2599 | if have_criterion: |
2599 | 2600 | criterion = CriterionLayer(config.text_config) |
2600 | 2601 | model_provider_class = Qwen3VLProvider |
2601 | 2602 | model_provider = model_provider_class.from_config(config) |
2602 | | - qwen3vl_model = Qwen3VLModel(model_provider, model_version=config.model_type, criterion=criterion) |
| 2603 | + qwen3vl_model = Qwen3VLModelDist(model_provider, model_version=config.model_type, criterion=criterion) |
2603 | 2604 | qwen3vl_model._gen_aoa_config = cls._gen_aoa_config |
2604 | 2605 | qwen3vl_model._gen_inv_aoa_config = cls._gen_inv_aoa_config |
2605 | 2606 | qwen3vl_model._get_tensor_parallel_mappings = cls._get_tensor_parallel_mappings |
|
0 commit comments