def monkey_patch_fused_moe_dummy_lora():
    """Install a replacement for ``LoRAModelManager.create_dummy_lora``.

    Works around the vLLM 0.21.0 dummy-LoRA warmup for ``FusedMoEWithLoRA``:
    upstream trims the packed replacement names to the local EP experts but
    then regenerates them from the untrimmed ``n_slices``, which can index
    past ``module.lora_a_stacked`` while profiling, before the server is
    ready. The replacement installed here recomputes the slice count from the
    trimmed list instead.
    """
    from vllm.lora.layers.base import BaseLayerWithLoRA
    from vllm.lora.lora_model import LoRAModel
    from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
    from vllm.lora.model_manager import LoRAModelManager

    def _dummy_weights(name, in_dim, out_dim, lora_rank, dtype):
        # Single funnel for every dummy-weight request; "cpu" is always passed
        # as the final positional argument.
        return LoRALayerWeights.create_dummy_lora_weights(
            name, in_dim, out_dim, lora_rank, dtype, "cpu"
        )

    def _patched_create_dummy_lora(
        self: LoRAModelManager,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Build a ``LoRAModel`` holding dummy weights for each LoRA module."""
        dummy_model = LoRAModel(lora_id, rank, {})

        for name, layer in self.model.named_modules():
            is_lora_target = (
                self._match_target_modules(name)
                and isinstance(layer, BaseLayerWithLoRA)
                and self._get_punica_wrapper(name) is not None
            )
            if not is_lora_target:
                continue

            leaf = name.split(".")[-1]
            layer_kind = layer.__class__.__name__

            # ---- modules that are not packed ---------------------------------
            if name not in self.packed_modules:
                assert embedding_modules is not None
                if leaf in embedding_modules:
                    if leaf == "lm_head":
                        in_dim = layer.lora_a_stacked[0].shape[-1]
                        out_dim = layer.lora_b_stacked[0].shape[-2]
                    else:
                        # Prefer the explicit size attributes of the base
                        # layer; fall back to its weight shape otherwise.
                        base = layer.base_layer
                        if hasattr(base, "org_vocab_size"):
                            in_dim = base.org_vocab_size
                        else:
                            in_dim = base.weight.shape[1]
                        if hasattr(base, "embedding_dim"):
                            out_dim = base.embedding_dim
                        else:
                            out_dim = base.weight.shape[0]
                    dummy_model.loras[name] = _dummy_weights(
                        name,
                        in_dim,
                        out_dim,
                        rank,
                        layer.lora_a_stacked[0].dtype,
                    )
                elif layer_kind == "FusedMoE3DWithLoRA":
                    # Two entries per 3D fused-MoE layer: w2 under the module
                    # name, w13 under "<module name>.base_layer".
                    dummy_model.loras[name] = _dummy_weights(
                        name,
                        layer.w2_input_size,
                        layer.w2_output_size,
                        rank * layer.w2_lora_a_stacked[0].shape[1],
                        layer.w2_lora_a_stacked[0].dtype,
                    )
                    dummy_model.loras[name + ".base_layer"] = _dummy_weights(
                        name,
                        layer.w13_input_size,
                        layer.w13_output_size,
                        rank * layer.w13_lora_a_stacked[0].shape[1],
                        layer.w13_lora_a_stacked[0].dtype,
                    )
                else:
                    dummy_model.loras[name] = _dummy_weights(
                        name,
                        layer.lora_a_stacked[0].shape[-1],
                        layer.lora_b_stacked[0].shape[-2],
                        rank,
                        layer.lora_a_stacked[0].dtype,
                    )
                continue

            # ---- packed modules ----------------------------------------------
            slice_names = self.packed_modules_mapping[leaf]
            slice_count = getattr(layer, "n_slices", len(slice_names))
            is_fused_moe = layer_kind == "FusedMoEWithLoRA"
            if is_fused_moe:
                # The fix: trim the names first, then derive the slice count
                # from the trimmed list so the names are not regenerated from
                # the untrimmed n_slices further down.
                slice_names = slice_names[: len(layer.lora_a_stacked) // self.lora_slots]
                slice_count = len(slice_names)

            if slice_count != len(slice_names):
                slice_names = [f"slice_{i}" for i in range(slice_count)]

            subloras: list[LoRALayerWeights | None] = [
                _dummy_weights(
                    name + "." + slice_name,
                    layer.lora_a_stacked[idx].shape[-1],
                    layer.lora_b_stacked[idx].shape[-2],
                    rank,
                    layer.lora_a_stacked[idx].dtype,
                )
                for idx, slice_name in enumerate(slice_names)
            ]

            if is_fused_moe:
                if self._is_non_gated_moe and subloras:
                    subloras = self._pad_lora_pairs_to_triplets(subloras)
                packed = PackedLoRALayerWeights.pack_moe(
                    subloras, name, is_non_gated_moe=self._is_non_gated_moe
                )
            else:
                packed = PackedLoRALayerWeights.pack(subloras)
            dummy_model.loras[name] = packed

        return dummy_model

    LoRAModelManager.create_dummy_lora = _patched_create_dummy_lora