Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions src/prime_rl/inference/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def transformers_v5_compat():
_patch_qwen35_lora()
_patch_lora_key_prefix()
monkey_patch_deep_gemm_silu_mul_quant_int64()
monkey_patch_vllm_layerwise_reload_alias_buffers()
monkey_patch_vllm_padded_input_scrub()
monkey_patch_return_routed_experts_with_nixl_connector()

Expand Down Expand Up @@ -67,38 +66,6 @@ def _post_init(config: VllmConfig):
logger.warning("Enabled vLLM routed-experts capture with NIXL connector patch.")


def monkey_patch_vllm_layerwise_reload_alias_buffers():
    """Swap vLLM's layerwise-reload copy step for a variant that skips aliased buffers."""
    # Why this exists: during layerwise reload vLLM materializes every buffer
    # as its own tensor and afterwards copies it back into the original kernel
    # storage. If a buffer is merely a view of a parameter (for example
    # NemotronH Mamba's mixer.conv_weights, which views mixer.conv1d.weight),
    # that copy-back overwrites the parameter's storage with garbage *after*
    # the parameter itself was reloaded correctly. The replacement below leaves
    # such buffers alone; _place_kernel_tensors re-registers the original view,
    # which automatically mirrors the parameter.
    #
    # Drop this patch as soon as the vLLM release we pin/use contains
    # https://github.com/vllm-project/vllm/pull/42481.
    from vllm.logger import init_logger
    from vllm.model_executor.model_loader.reload import layerwise as reload_layerwise

    logger = init_logger(__name__)

    def _copy_and_restore_kernel_tensors(layer: torch.nn.Module, info: reload_layerwise.LayerReloadingInfo):
        assert info.kernel_tensors is not None
        kernel_params, kernel_buffers = info.kernel_tensors

        # Record the storage address of every parameter reachable from the
        # layer (children included) before any copy takes place.
        seen_param_storages = set()
        for layer_param in layer.parameters(recurse=True):
            seen_param_storages.add(layer_param.untyped_storage().data_ptr())

        # Parameters are always copied back into their kernel tensors.
        for attr_name, kernel_param in kernel_params.items():
            kernel_param.data.copy_(getattr(layer, attr_name))

        # Buffers are copied back only when they own their storage; a buffer
        # sharing storage with a parameter is skipped.
        for attr_name, kernel_buffer in kernel_buffers.items():
            buffer_storage = kernel_buffer.untyped_storage().data_ptr()
            if buffer_storage not in seen_param_storages:
                kernel_buffer.data.copy_(getattr(layer, attr_name))

        reload_layerwise._place_kernel_tensors(layer, info)

    reload_layerwise._copy_and_restore_kernel_tensors = _copy_and_restore_kernel_tensors
    logger.warning("Enabled vLLM layerwise reload alias-buffer patch.")


@triton.jit
def _silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel(
y_ptr,
Expand Down
Loading