Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions src/prime_rl/inference/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def transformers_v5_compat():

_patch_qwen35_lora()
_patch_lora_key_prefix()
monkey_patch_fused_moe_dummy_lora()
monkey_patch_dp_engine_core_pause_resume_deadlock()
monkey_patch_vllm_layerwise_reload_alias_buffers()

Expand Down Expand Up @@ -247,6 +248,125 @@ def check_unexpected_modules(modules: dict):
LoRAModel.from_local_checkpoint = classmethod(_patched_from_local_checkpoint)


def monkey_patch_fused_moe_dummy_lora():
    """Install an EP-aware ``LoRAModelManager.create_dummy_lora``.

    Background (carried over from the original patch notes): in vLLM 0.21.0
    the stock ``create_dummy_lora`` narrows the replacement list for
    ``FusedMoEWithLoRA`` to the local EP experts and then rebuilds that list
    from the original ``n_slices`` value, which can index past
    ``module.lora_a_stacked`` during profiling before the server is ready.

    The replacement installed here keeps the narrowed list for
    ``FusedMoEWithLoRA`` layers: the slice count is taken from the narrowed
    list itself, so the ``slice_{i}`` regeneration is only ever applied to
    other packed layer types.
    """
    from vllm.lora.layers.base import BaseLayerWithLoRA
    from vllm.lora.lora_model import LoRAModel
    from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
    from vllm.lora.model_manager import LoRAModelManager

    def _patched_create_dummy_lora(
        self: LoRAModelManager,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Build a dummy ``LoRAModel`` holding one entry per LoRA-wrapped module.

        Every entry is produced by ``LoRALayerWeights.create_dummy_lora_weights``
        on ``"cpu"``; packed modules are combined with
        ``PackedLoRALayerWeights.pack`` / ``pack_moe``.
        """
        dummy_model = LoRAModel(lora_id, rank, {})
        make_dummy = LoRALayerWeights.create_dummy_lora_weights

        for module_name, module in self.model.named_modules():
            # Guard clauses, evaluated in the same order as the original
            # short-circuiting ``or`` chain.
            if not self._match_target_modules(module_name):
                continue
            if not isinstance(module, BaseLayerWithLoRA):
                continue
            if self._get_punica_wrapper(module_name) is None:
                continue

            leaf_name = module_name.split(".")[-1]

            if module_name not in self.packed_modules:
                assert embedding_modules is not None

                if leaf_name in embedding_modules:
                    if leaf_name == "lm_head":
                        in_features = module.lora_a_stacked[0].shape[-1]
                        out_features = module.lora_b_stacked[0].shape[-2]
                    else:
                        # Prefer the explicit vocab/embedding attributes and
                        # fall back to the base weight's shape when absent.
                        base_layer = module.base_layer
                        if hasattr(base_layer, "org_vocab_size"):
                            in_features = base_layer.org_vocab_size
                        else:
                            in_features = base_layer.weight.shape[1]
                        if hasattr(base_layer, "embedding_dim"):
                            out_features = base_layer.embedding_dim
                        else:
                            out_features = base_layer.weight.shape[0]
                    dummy_model.loras[module_name] = make_dummy(
                        module_name,
                        in_features,
                        out_features,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Two entries for this layer type: the w2 weights are
                    # stored under the module name, the w13 weights under
                    # "<module name>.base_layer".
                    dummy_model.loras[module_name] = make_dummy(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        rank * module.w2_lora_a_stacked[0].shape[1],
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    dummy_model.loras[module_name + ".base_layer"] = make_dummy(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        rank * module.w13_lora_a_stacked[0].shape[1],
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                else:
                    dummy_model.loras[module_name] = make_dummy(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                continue

            # Packed module: one dummy sub-LoRA per slice, packed afterwards.
            slice_names = self.packed_modules_mapping[leaf_name]
            is_fused_moe = module.__class__.__name__ == "FusedMoEWithLoRA"
            if is_fused_moe:
                # Keep only as many names as there are stacked entries per
                # LoRA slot; this narrowed list is used as-is (the point of
                # the patch) and is never regenerated below.
                slice_names = slice_names[: len(module.lora_a_stacked) // self.lora_slots]
            else:
                slice_count = getattr(module, "n_slices", len(slice_names))
                if slice_count != len(slice_names):
                    slice_names = [f"slice_{idx}" for idx in range(slice_count)]

            subloras: list[LoRALayerWeights | None] = [
                make_dummy(
                    module_name + "." + slice_name,
                    module.lora_a_stacked[idx].shape[-1],
                    module.lora_b_stacked[idx].shape[-2],
                    rank,
                    module.lora_a_stacked[idx].dtype,
                    "cpu",
                )
                for idx, slice_name in enumerate(slice_names)
            ]

            if is_fused_moe:
                if self._is_non_gated_moe and subloras:
                    subloras = self._pad_lora_pairs_to_triplets(subloras)
                packed = PackedLoRALayerWeights.pack_moe(
                    subloras, module_name, is_non_gated_moe=self._is_non_gated_moe
                )
            else:
                packed = PackedLoRALayerWeights.pack(subloras)
            dummy_model.loras[module_name] = packed

        return dummy_model

    LoRAModelManager.create_dummy_lora = _patched_create_dummy_lora


# Monkeypatch LoadLoRAAdapter to allow loading the same adapter multiple times
# TODO: may be removable if we pass load_inplace=True (supported since vLLM 0.18, PR #31326)
def monkey_patch_load_lora_adapter():
Expand Down