From f59c1e64e00ee0faf565910e1ad4e8511aaca8cf Mon Sep 17 00:00:00 2001 From: lichang Date: Mon, 1 Jun 2026 15:41:59 -0600 Subject: [PATCH 1/3] feat: Add FlexMLRT NPU vision backend for Qwen2.5-VL Introduce pluggable NPU vision support without scheduler or engine pipelining changes. Vision encoding runs synchronously on the NPU when VLLM_VISION_NPU_BACKEND=flexmlrt is set, keeping core v1 scheduling untouched for easier upstream review. Co-authored-by: Cursor --- .../passes/fusion/act_quant_fusion.py | 5 +- vllm/envs.py | 12 + vllm/model_executor/models/qwen2.py | 2 +- vllm/model_executor/models/qwen2_5_vl.py | 194 +++++++++++- vllm/model_executor/models/vision.py | 51 ++++ vllm/multimodal/utils.py | 100 ++++++- vllm/vision_npu/__init__.py | 13 + vllm/vision_npu/backend.py | 52 ++++ vllm/vision_npu/bridge/CMakeLists.txt | 45 +++ vllm/vision_npu/bridge/vision_flexmlrt.cpp | 218 ++++++++++++++ vllm/vision_npu/cpu_preprocess.py | 279 ++++++++++++++++++ vllm/vision_npu/flexmlrt_backend.py | 142 +++++++++ 12 files changed, 1099 insertions(+), 14 deletions(-) create mode 100644 vllm/vision_npu/__init__.py create mode 100644 vllm/vision_npu/backend.py create mode 100644 vllm/vision_npu/bridge/CMakeLists.txt create mode 100644 vllm/vision_npu/bridge/vision_flexmlrt.cpp create mode 100644 vllm/vision_npu/cpu_preprocess.py create mode 100644 vllm/vision_npu/flexmlrt_backend.py diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index 73234ec7920d..c55d90a684cf 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -39,7 +39,10 @@ if silu_and_mul_nvfp4_quant_supported: FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -if current_platform.is_cuda_alike(): +# Check if the per-block quant operation is available (newer ROCm/CUDA versions) +if current_platform.is_cuda_alike() and hasattr( + torch.ops._C, "silu_and_mul_per_block_quant" +): FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default diff --git a/vllm/envs.py b/vllm/envs.py index 2448d3b5a873..e5079e6438c4 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -212,6 +212,10 @@ VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False + VLLM_VISION_NPU_BACKEND: str = "" + VLLM_VISION_NPU_CACHE: str | None = None + VLLM_VISION_NPU_DEVICE: str | None = None + VLLM_NPU_TIMING: bool = False VLLM_MORIIO_QP_PER_TRANSFER: int = 1 VLLM_MORIIO_POST_BATCH_SIZE: int = -1 VLLM_MORIIO_NUM_WORKERS: int = 1 @@ -1744,6 +1748,14 @@ def _get_or_set_default() -> str: # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes # Triton compilation to fail. "VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))), + # NPU vision backend to use (e.g., "flexmlrt" for FlexMLRT backend) + "VLLM_VISION_NPU_BACKEND": lambda: os.getenv("VLLM_VISION_NPU_BACKEND", ""), + # Path to NPU model cache directory (required for FlexMLRT backend) + "VLLM_VISION_NPU_CACHE": lambda: os.getenv("VLLM_VISION_NPU_CACHE"), + # NPU device name (e.g., "stx" for Strix, "phx" for Phoenix) + "VLLM_VISION_NPU_DEVICE": lambda: os.getenv("VLLM_VISION_NPU_DEVICE"), + # Enable NPU timing debug logs + "VLLM_NPU_TIMING": lambda: os.getenv("VLLM_NPU_TIMING", "0") == "1", # Enable CUDA compatibility mode for datacenter GPUs with older # driver versions than the CUDA toolkit major version of vLLM. "VLLM_ENABLE_CUDA_COMPATIBILITY": lambda: ( diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 27aa6175b9bc..90083bb87451 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -416,7 +416,7 @@ def __init__( else: self.norm = PPMissingLayer() - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + def embed_input_ids(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c42a11686e47..df54948c43fa 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -581,18 +581,40 @@ def __init__( ) -> None: super().__init__() + # Store minimal config needed for both NPU and PyTorch paths + self.out_hidden_size = vision_config.out_hidden_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + + # Check NPU backend before creating PyTorch modules + from vllm.model_executor.models.vision import ( + get_npu_vision_backend, + use_npu_vision_backend, + ) + + if use_npu_vision_backend(): + try: + self.npu_backend = get_npu_vision_backend() + logger.info("[Qwen2.5VL] Using NPU vision backend") + return + except Exception as e: + logger.error("[Qwen2.5VL] NPU backend init failed: %s", e) + raise RuntimeError( + f"NPU vision backend initialization failed: {e}. " + "Set VLLM_VISION_NPU_BACKEND='' to use PyTorch backend." + ) from e + + self.npu_backend = None patch_size = vision_config.patch_size temporal_patch_size = vision_config.temporal_patch_size in_channels = vision_config.in_channels depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads - self.out_hidden_size = vision_config.out_hidden_size # args for get_window_index_thw self.window_size = vision_config.window_size self.patch_size = vision_config.patch_size - self.spatial_merge_size = vision_config.spatial_merge_size self.fullatt_block_indexes = vision_config.fullatt_block_indexes self.spatial_merge_unit = self.spatial_merge_size**2 self.patch_embed = Qwen2_5_VisionPatchEmbed( @@ -653,11 +675,22 @@ def __init__( @property def dtype(self) -> torch.dtype: - return self.patch_embed.proj.weight.dtype + if hasattr(self, "npu_backend") and self.npu_backend is not None: + return torch.bfloat16 + if hasattr(self, "patch_embed"): + return self.patch_embed.proj.weight.dtype + # Safe fallback if neither exists + return torch.bfloat16 @property def device(self) -> torch.device: - return self.patch_embed.proj.weight.device + if hasattr(self, "npu_backend") and self.npu_backend is not None: + # NPU outputs are on CPU, transfer to GPU happens in forward + return torch.device("cpu") + if hasattr(self, "patch_embed"): + return self.patch_embed.proj.weight.device + # Safe fallback + return torch.device("cpu") def rotary_pos_emb_thw(self, t, h, w): hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) @@ -787,6 +820,94 @@ def forward( x: torch.Tensor, grid_thw: list[list[int]], ) -> torch.Tensor: + # Dispatch to NPU or PyTorch backend + if hasattr(self, "npu_backend") and self.npu_backend is not None: + return self._forward_npu(x, grid_thw) + else: + return self._forward_pytorch(x, grid_thw) + + def _forward_npu( + self, pixel_values: torch.Tensor, grid_thw: list[list[int]] + ) -> torch.Tensor: + """Forward pass using NPU backend.""" + import logging + import time + + import numpy as np + + logger = logging.getLogger(__name__) + + # Convert PyTorch → NumPy (handle bfloat16 by converting to float32 first) + if pixel_values.dtype == torch.bfloat16: + pixel_values_np = pixel_values.cpu().float().numpy() + else: + pixel_values_np = pixel_values.cpu().numpy().astype(np.float32) + grid_thw_np = np.array(grid_thw, dtype=np.int64) + + # Run NPU inference + embeddings_np = self.npu_backend.forward(pixel_values_np, grid_thw_np) + + # Convert back to PyTorch and transfer to GPU for LLM + import vllm.envs as envs + + if envs.VLLM_NPU_TIMING: + gpu_transfer_start = time.monotonic() + embeddings = torch.from_numpy(embeddings_np).to( + device="cuda", dtype=torch.bfloat16 + ) + gpu_transfer_ms = (time.monotonic() - gpu_transfer_start) * 1000 + logger.debug( + "[NPU Timing] CPU→GPU transfer: %.2fms (%.2f MB)", + gpu_transfer_ms, + embeddings_np.nbytes / 1024**2, + ) + logger.debug("[Vision→LLM] Vision embeddings shape: %s", embeddings.shape) + else: + embeddings = torch.from_numpy(embeddings_np).to( + device="cuda", dtype=torch.bfloat16 + ) + + # NPU model outputs compressed tokens but vLLM expects uncompressed + # count. We need to pad/repeat to match expected count based on grid_thw + actual_tokens = embeddings.shape[0] + merge_size = self.spatial_merge_size + expected_tokens_per_image = [ + (t * h * w) // (merge_size * merge_size) for t, h, w in grid_thw + ] + total_expected = sum(expected_tokens_per_image) + + if actual_tokens != total_expected: + logger.warning( + "[NPU] Token count mismatch: NPU output %s tokens, " + "but vLLM expects %s based on grid_thw. " + "Repeating tokens to match expected count.", + actual_tokens, + total_expected, + ) + repeat_factor = total_expected / actual_tokens + if repeat_factor == int(repeat_factor): + embeddings = embeddings.repeat_interleave(int(repeat_factor), dim=0) + else: + embeddings = embeddings.unsqueeze(0).unsqueeze(0) + embeddings = torch.nn.functional.interpolate( + embeddings, + size=(total_expected, embeddings.shape[-1]), + mode="nearest", + ) + embeddings = embeddings.squeeze(0).squeeze(0) + + logger.debug( + "[NPU] Padded from %s to %s tokens", actual_tokens, embeddings.shape[0] + ) + + return embeddings + + def _forward_pytorch( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + """Original PyTorch forward pass.""" # patchify seq_len, _ = x.size() rotary_pos_emb_cos = [] @@ -889,6 +1010,12 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + if self.npu_backend is not None: + logger.info( + "[Qwen2.5VL Vision] Skipping weight loading (using NPU backend)" + ) + return set() + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -1231,8 +1358,25 @@ def _process_image_input( image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. - merge_size = self.visual.spatial_merge_size - sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + # When using NPU backend, merge is already done in NPU, so use actual + # output size + if hasattr(self.visual, "npu_backend") and self.visual.npu_backend is not None: + # NPU backend already did spatial merging - use actual output sizes + # For single image: sizes = [actual_num_tokens] + # For batched images: split based on actual output + num_images = len(grid_thw_list) + if num_images == 1: + # Single image - return the whole embedding + sizes = [image_embeds.shape[0]] + else: + # Multiple images - need to split based on actual grid sizes + # Each image: (T*H*W) // (merge_size^2) tokens after NPU + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + else: + # PyTorch backend - calculate expected size + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return image_embeds.split(sizes) def _postprocess_image_embeds_evs( @@ -1495,6 +1639,22 @@ def compute_logits( return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + if hasattr(self.visual, "npu_backend") and self.visual.npu_backend is not None: + logger.info( + "[Qwen2.5VL Model] Filtering out visual weights (using NPU backend)" + ) + filtered_weights = [] + visual_weight_count = 0 + for name, weight in weights: + if name.startswith("visual."): + visual_weight_count += 1 + continue + filtered_weights.append((name, weight)) + logger.info( + "[Qwen2.5VL Model] Skipped %s visual weights", visual_weight_count + ) + weights = filtered_weights + loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1526,3 +1686,25 @@ def get_num_mm_connector_tokens( vision_config = hf_config.vision_config merge_size = vision_config.spatial_merge_size return num_vision_tokens // merge_size**2 + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: tuple[torch.Tensor, ...] | None = None, + is_multimodal: torch.Tensor | None = None, + ) -> torch.Tensor: + """Embed token ids and merge multimodal embeddings (V1 MM path).""" + inputs_embeds = self.language_model.model.embed_input_ids(input_ids) + if ( + multimodal_embeddings is not None + and is_multimodal is not None + and len(multimodal_embeddings) > 0 + ): + from vllm.model_executor.models.utils import _merge_multimodal_embeddings + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds, + multimodal_embeddings, + is_multimodal, + ) + return inputs_embeds diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index e6a243006759..c01f291fb07b 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -601,3 +601,54 @@ def get_llm_pos_ids_for_vision( llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) return llm_pos_ids + + +# --------------------------------------------------------------------------- +# NPU Vision Backend Support +# --------------------------------------------------------------------------- + + +def use_npu_vision_backend() -> bool: + """Check if NPU backend is enabled for vision processing. + + Returns: + True if VLLM_VISION_NPU_BACKEND environment variable is set to + a supported backend (flexmlrt), False otherwise. + """ + import vllm.envs as envs + + backend = ( + envs.VLLM_VISION_NPU_BACKEND.lower() if envs.VLLM_VISION_NPU_BACKEND else "" + ) + return backend == "flexmlrt" + + +def get_npu_vision_backend(): + """Get NPU vision backend instance if enabled. + + Returns: + NPUVisionBackend instance if NPU backend is enabled, None otherwise. + + Raises: + ValueError: If backend name is recognized but initialization fails. + ImportError: If backend dependencies are not available. + """ + import vllm.envs as envs + + backend_name = ( + envs.VLLM_VISION_NPU_BACKEND.lower() if envs.VLLM_VISION_NPU_BACKEND else "" + ) + + if backend_name == "flexmlrt": + model_cache = envs.VLLM_VISION_NPU_CACHE + if not model_cache: + raise ValueError( + "VLLM_VISION_NPU_CACHE must be set when using FlexMLRT backend" + ) + device_name = envs.VLLM_VISION_NPU_DEVICE or "stx" + + from vllm.vision_npu.flexmlrt_backend import FlexMLRTVisionBackend + + return FlexMLRTVisionBackend(model_cache, device_name) + + return None diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 2d321cb67b4e..0f7146eba3aa 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -30,6 +30,21 @@ torch = LazyLoader("torch", globals(), "torch") +def _is_npu_vision_backend() -> bool: + """Check if NPU vision backend is active (requires per-request processing). + + NPU backends like FlexMLRT have fixed input size requirements and cannot + batch vision inputs from multiple requests together. This function detects + when NPU backend is being used so we can apply special handling. + """ + import vllm.envs as envs + + backend = ( + envs.VLLM_VISION_NPU_BACKEND.lower() if envs.VLLM_VISION_NPU_BACKEND else "" + ) + return backend == "flexmlrt" + + def encode_audio_base64( audio: np.ndarray, sampling_rate: int, @@ -225,6 +240,10 @@ def group_and_batch_mm_kwargs( To simplify the implementation of `embed_multimodal`, we add another restriction that the items in a batch must belong to the same modality. + Special handling for NPU backends: vision inputs are NOT batched across + requests to support hardware with fixed input sizes (e.g., FlexMLRT NPU). + Standard GPU backends use normal batching behavior (unchanged). + Args: mm_kwargs: List of `(modality, item)`. device: The device to place the grouped tensors on. @@ -236,15 +255,84 @@ def group_and_batch_mm_kwargs( - `kwargs` is a dictionary of keyword arguments to pass to the model; - `num_items` is the corresponding number of items. """ + import logging + import threading + from datetime import datetime + + import vllm.envs as envs + + logger = logging.getLogger(__name__) + + # Auto-detect NPU backend for special handling + using_npu = _is_npu_vision_backend() + + if using_npu and envs.VLLM_NPU_TIMING: + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + thread_id = threading.get_ident() + num_items_total = len(mm_kwargs) + logger.debug( + "[MM Batching] %s Thread-%s: Processing %s items (NPU mode)", + timestamp, + thread_id, + num_items_total, + ) + for modality, group in groupby(mm_kwargs, key=lambda x: x[0]): items_lst = [item for _, item in group] - for num_items, mm_kwargs_batch in group_and_batch_mm_items( - items_lst, - device=device, - pin_memory=pin_memory, - ): - yield modality, num_items, mm_kwargs_batch + # NPU path: process each request separately (no cross-request batching) + is_vision_on_npu = using_npu and modality in ("image", "video") + + if is_vision_on_npu: + # Debug: Log that we're using NPU single-item batching + if envs.VLLM_NPU_TIMING: + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.debug( + "[MM Batching] %s Thread-%s: NPU path - " + "yielding %s single-item batches for %s", + timestamp, + threading.get_ident(), + len(items_lst), + modality, + ) + + # Yield single-item batches to maintain fixed input size + for idx, item in enumerate(items_lst): + mm_kwargs_batch = _batch_mm_items( + [item], + device=device, + pin_memory=pin_memory, + ) + + if envs.VLLM_NPU_TIMING: + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.debug( + "[MM Batching] %s Thread-%s: Yielding item %s/%s", + timestamp, + threading.get_ident(), + idx + 1, + len(items_lst), + ) + + yield modality, 1, mm_kwargs_batch + else: + # Standard GPU path: original batching logic (unchanged) + if envs.VLLM_NPU_TIMING: + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.debug( + "[MM Batching] %s Thread-%s: GPU path - " + "using standard batching for %s", + timestamp, + threading.get_ident(), + modality, + ) + + for num_items, mm_kwargs_batch in group_and_batch_mm_items( + items_lst, + device=device, + pin_memory=pin_memory, + ): + yield modality, num_items, mm_kwargs_batch @deprecated( diff --git a/vllm/vision_npu/__init__.py b/vllm/vision_npu/__init__.py new file mode 100644 index 000000000000..e99e748abeb8 --- /dev/null +++ b/vllm/vision_npu/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Vision NPU backend infrastructure for vLLM. + +Provides pluggable NPU backends for vision processing in multimodal models. +""" + +from .backend import NPUVisionBackend +from .flexmlrt_backend import FlexMLRTVisionBackend + +__all__ = ["NPUVisionBackend", "FlexMLRTVisionBackend"] diff --git a/vllm/vision_npu/backend.py b/vllm/vision_npu/backend.py new file mode 100644 index 000000000000..8d953dca2e05 --- /dev/null +++ b/vllm/vision_npu/backend.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Abstract base class for vision NPU backends. +""" + +from abc import ABC, abstractmethod + +import numpy as np + + +class NPUVisionBackend(ABC): + """Base class for vision processing NPU backends. + + This abstract class defines the interface that all NPU vision backends + must implement. Different NPU implementations (FlexMLRT, ONNX Runtime, etc.) + can subclass this to provide hardware-accelerated vision processing. + """ + + @abstractmethod + def __init__(self, model_cache_path: str, device_name: str = "stx"): + """Load vision model onto NPU. + + Args: + model_cache_path: Path to pre-compiled NPU model cache + device_name: NPU device identifier (e.g., "stx" for Strix) + """ + pass + + @abstractmethod + def forward(self, pixel_values: np.ndarray, grid_thw: np.ndarray) -> np.ndarray: + """Run vision encoding on NPU. + + Args: + pixel_values: Input pixel data [seq_len, feature_dim] float32 + grid_thw: Grid dimensions [num_images, 3] int64 (temporal, height, width) + + Returns: + embeddings: Vision embeddings [merged_seq_len, hidden_dim] float32 + """ + pass + + @property + @abstractmethod + def output_dim(self) -> int: + """Output embedding dimension. + + Returns: + Hidden dimension of output embeddings (e.g., 3584 for Qwen2.5-VL) + """ + pass diff --git a/vllm/vision_npu/bridge/CMakeLists.txt b/vllm/vision_npu/bridge/CMakeLists.txt new file mode 100644 index 000000000000..a7a9b755b351 --- /dev/null +++ b/vllm/vision_npu/bridge/CMakeLists.txt @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +cmake_minimum_required(VERSION 3.18) +project(_vision_flexmlrt) + +# Find Python and pybind11 +find_package(Python REQUIRED COMPONENTS Interpreter Development) +find_package(pybind11 REQUIRED) + +# FlexMLRT paths (must be configured via command line or environment variables) +set(FLEXMLRT_INCLUDE_DIR "" CACHE PATH "FlexMLRT include directory") +set(FLEXMLRT_LIB_DIR "" CACHE PATH "FlexMLRT library directory") + +# Check if FlexMLRT paths are configured +if(NOT FLEXMLRT_INCLUDE_DIR OR NOT FLEXMLRT_LIB_DIR) + message(FATAL_ERROR + "FlexMLRT paths not configured. Please specify:\n" + " -DFLEXMLRT_INCLUDE_DIR=/path/to/flexmlRT/include\n" + " -DFLEXMLRT_LIB_DIR=/path/to/flexmlRT/build/lib\n" + "Or set environment variables:\n" + " export FLEXMLRT_INCLUDE_DIR=/path/to/flexmlRT/include\n" + " export FLEXMLRT_LIB_DIR=/path/to/flexmlRT/build/lib") +endif() + +# Verify paths exist +if(NOT EXISTS "${FLEXMLRT_INCLUDE_DIR}") + message(FATAL_ERROR "FlexMLRT include directory does not exist: ${FLEXMLRT_INCLUDE_DIR}") +endif() +if(NOT EXISTS "${FLEXMLRT_LIB_DIR}") + message(FATAL_ERROR "FlexMLRT library directory does not exist: ${FLEXMLRT_LIB_DIR}") +endif() + +message(STATUS "FlexMLRT include: ${FLEXMLRT_INCLUDE_DIR}") +message(STATUS "FlexMLRT library: ${FLEXMLRT_LIB_DIR}") + +# Create pybind11 module with CPU preprocessing support +pybind11_add_module(_vision_flexmlrt_cpu vision_flexmlrt.cpp) +target_include_directories(_vision_flexmlrt_cpu PRIVATE ${FLEXMLRT_INCLUDE_DIR}) +target_link_directories(_vision_flexmlrt_cpu PRIVATE ${FLEXMLRT_LIB_DIR}) +target_link_libraries(_vision_flexmlrt_cpu PRIVATE flexmlrt) +target_compile_features(_vision_flexmlrt_cpu PRIVATE cxx_std_17) + +# Install to parent directory (vllm/vision_npu/) +install(TARGETS _vision_flexmlrt_cpu LIBRARY DESTINATION ${CMAKE_SOURCE_DIR}/..) diff --git a/vllm/vision_npu/bridge/vision_flexmlrt.cpp b/vllm/vision_npu/bridge/vision_flexmlrt.cpp new file mode 100644 index 000000000000..b7016b15cce6 --- /dev/null +++ b/vllm/vision_npu/bridge/vision_flexmlrt.cpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// +// vision_flexmlrt.cpp — MODIFIED VERSION for CPU preprocessing +// +// This version accepts CPU-preprocessed [1073, 4, 1280] input instead of raw +// pixel_values + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +// Debug logging gated by VLLM_LOGGING_LEVEL=DEBUG +inline bool is_vllm_debug() { + static int debug_enabled = -1; + if (debug_enabled == -1) { + const char* level = std::getenv("VLLM_LOGGING_LEVEL"); + debug_enabled = (level && std::strcmp(level, "DEBUG") == 0) ? 1 : 0; + } + return debug_enabled == 1; +} + +// Use stderr (not PySys_WriteStdout) so logging is safe while the GIL is +// released during model_->forward(). +#define DEBUG_LOG(expr) \ + do { \ + if (is_vllm_debug()) { \ + std::ostringstream oss; \ + oss << "[FlexMLRT] " << expr; \ + std::cerr << oss.str() << '\n'; \ + } \ + } while (0) + +// Build ErtIoTypeNew tensor descriptor +static flexmlrt::client::ErtIoTypeNew makeIO( + const std::string& name, int index, void* data, size_t size_bytes, + const std::string& dtype, const std::vector& shape) { + flexmlrt::client::ErtIoTypeNew io; + io.name = name; + io.idx = index; + io.data = data; + io.size = size_bytes; + io.type = dtype; + io.shape = shape; + return io; +} + +// VisionFlexMLRTModel with CPU preprocessing support +class VisionFlexMLRTModel { + public: + VisionFlexMLRTModel(const std::string& model_cache, + const std::string& device_name) + : device_name_(device_name) { + DEBUG_LOG(" VisionFlexMLRTModel constructor START"); + DEBUG_LOG(" model_cache: " << model_cache); + DEBUG_LOG(" device_name: " << device_name); + + // Create options object (will be destroyed after model creation) + flexmlrt::client::Options opts; + opts.modelPath = model_cache; + opts.deviceName = device_name; + opts.subgraphName = "0"; // Specify subgraph name explicitly + opts.executeMode = 2; // From test_generic line 446 + + DEBUG_LOG(" Creating FlexMLRT Model object..."); + try { + model_ = std::make_unique(opts); + DEBUG_LOG(" FlexMLRT Model object created"); + } catch (const std::exception& e) { + std::cerr << "[FlexMLRT ERROR] FlexMLRT Model creation threw exception: " + << e.what() << std::endl; + throw std::runtime_error( + std::string("Failed to load FlexMLRT vision model: ") + e.what()); + } + // opts goes out of scope here - memory automatically freed + + if (!model_->good()) { + std::cerr << "[FlexMLRT ERROR] model->good() returned false" << std::endl; + throw std::runtime_error( + "FlexMLRT vision model creation failed - check model cache and " + "device availability"); + } + DEBUG_LOG(" model->good() returned true"); + DEBUG_LOG(" VisionFlexMLRTModel constructor END (opts memory released)"); + } + + // Forward pass with CPU-preprocessed input [1073, 4, 1280] + py::array_t forward(py::array_t preprocessed_input) { + DEBUG_LOG(" forward() START (CPU-preprocessed input)"); + + auto buf = preprocessed_input.request(); + DEBUG_LOG(" Input ndim: " << buf.ndim); + + if (buf.ndim != 3) { + throw std::runtime_error( + "preprocessed_input must be 3D array [1073, 4, 1280]"); + } + + int64_t dim0 = buf.shape[0]; // 1073 + int64_t dim1 = buf.shape[1]; // 4 + int64_t dim2 = buf.shape[2]; // 1280 + + DEBUG_LOG(" Input shape: [" << dim0 << ", " << dim1 << ", " << dim2 << "]"); + + if (dim0 != 1073 || dim1 != 4 || dim2 != 1280) { + throw std::runtime_error( + "Expected input shape [1073, 4, 1280], got [" + std::to_string(dim0) + + ", " + std::to_string(dim1) + ", " + std::to_string(dim2) + "]"); + } + + // Build input tensors + std::vector ifms; + + // Input name from NPU partition ONNX: "/blocks/Gather_output_0" + ifms.push_back(makeIO("/blocks/Gather_output_0", 0, buf.ptr, + dim0 * dim1 * dim2 * sizeof(float), "float32", + {dim0, dim1, dim2})); + DEBUG_LOG(" Input tensor built: /blocks/Gather_output_0 [1073, 4, 1280]"); + + // Output tensor + // From NPU partition ONNX: "/merger/merger/mlp/mlp.2/Gemm_output_0" [1073, + // 3584] + int64_t out_dim0 = 1073; + int64_t out_dim1 = 3584; + + std::vector output_buf(out_dim0 * out_dim1); + std::vector ofms; + ofms.push_back(makeIO("/merger/merger/mlp/mlp.2/Gemm_output_0", 0, + output_buf.data(), output_buf.size() * sizeof(float), + "float32", {out_dim0, out_dim1})); + DEBUG_LOG( + " Output tensor built: /merger/merger/mlp/mlp.2/Gemm_output_0 [1073, " + "3584]"); + + std::vector wts; + + // Run NPU inference + DEBUG_LOG(" Calling model->forward()..."); + DEBUG_LOG(" Releasing GIL to allow GPU parallelization..."); + try { + // CRITICAL: Release GIL during NPU execution to allow GPU to run in + // parallel NPU inference takes ~11 seconds - other Python threads must be + // able to proceed + py::gil_scoped_release release; + model_->forward(ifms, ofms, wts); + // GIL automatically reacquired when 'release' goes out of scope + DEBUG_LOG(" model->forward() returned successfully (GIL reacquired)"); + } catch (const std::exception& e) { + std::cerr << "[FlexMLRT ERROR] model->forward() threw exception: " + << e.what() << std::endl; + throw std::runtime_error(std::string("FlexMLRT forward failed: ") + + e.what()); + } + + // Copy output to numpy array + DEBUG_LOG(" Copying output to numpy array..."); + py::array_t result({out_dim0, out_dim1}); + auto result_buf = result.request(); + std::memcpy(result_buf.ptr, output_buf.data(), + output_buf.size() * sizeof(float)); + + // Explicitly clear temporary buffers (helps with memory fragmentation) + output_buf.clear(); + output_buf.shrink_to_fit(); + ifms.clear(); + ofms.clear(); + + DEBUG_LOG(" forward() END (temporary buffers released)"); + + return result; + } + + int output_dim() const { + return 3584; // Fixed for Qwen2.5-VL + } + + private: + std::unique_ptr model_; + std::string device_name_; + // Removed unused members: + // - std::unique_ptr rai_loader_; (never initialized or used) + // - int output_dim_; (unused, output_dim() returns hardcoded 3584) +}; + +// pybind11 module +PYBIND11_MODULE(_vision_flexmlrt_cpu, m) { + m.doc() = "FlexMLRT vision model with CPU preprocessing support"; + + py::class_(m, "VisionFlexMLRTModel") + .def(py::init(), py::arg("model_cache"), + py::arg("device_name") = "stx", + "Load FlexMLRT vision model\n\n" + "Args:\n" + " model_cache: Path to VAIP model cache (vaiml_par_0 directory)\n" + " device_name: XRT device name (default: 'stx')") + .def("forward", &VisionFlexMLRTModel::forward, + py::arg("preprocessed_input"), + "Run vision encoding on NPU with CPU-preprocessed input\n\n" + "Args:\n" + " preprocessed_input: [1073, 4, 1280] float32 array " + "(CPU-preprocessed)\n\n" + "Returns:\n" + " embeddings: [1073, 3584] float32 array") + .def("output_dim", &VisionFlexMLRTModel::output_dim, + "Get output embedding dimension"); +} diff --git a/vllm/vision_npu/cpu_preprocess.py b/vllm/vision_npu/cpu_preprocess.py new file mode 100644 index 000000000000..ba6e9d3afbcc --- /dev/null +++ b/vllm/vision_npu/cpu_preprocess.py @@ -0,0 +1,279 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +CPU preprocessing operations for VitisAI-compiled vision models. + +This module implements the CPU operations that VitisAI ExecutionProvider +normally handles automatically. When using FlexMLRT directly, we must +manually implement these operations. + +For Qwen2.5-VL vision model: +- Input: pixel_values [4292, 1176] from HuggingFace processor +- Output: preprocessed [1073, 4, 1280] ready for NPU +- Postprocessing: Apply reverse_index Gather to NPU output +""" + +import logging + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +class Qwen2_5_VL_CPUPreprocessor: + """CPU preprocessing for Qwen2.5-VL vision model before NPU execution.""" + + def __init__(self, model_cache_dir: str): + """ + Initialize CPU preprocessor with required parameters. + + Args: + model_cache_dir: Path to NPU model cache directory containing ONNX model + """ + import os + + import onnx + + # Load ONNX model to extract parameters + # model_cache_dir is typically: .../qwen2_5_vl_vision_stitched_7b/vaiml_par_0 + # We need to go up two levels to find the .onnx file + onnx_model_path = os.path.join( + os.path.dirname(os.path.dirname(model_cache_dir)), + "qwen2_5_vl_vision_stitched_7b.onnx", + ) + + if not os.path.exists(onnx_model_path): + logger.warning( + "[CPU Preprocess] ONNX not found at %s, trying alternative path", + onnx_model_path, + ) + # Alternative: look in parent directory + alt_path = os.path.join( + os.path.dirname(model_cache_dir), "qwen2_5_vl_vision_stitched_7b.onnx" + ) + if os.path.exists(alt_path): + onnx_model_path = alt_path + else: + raise FileNotFoundError( + f"Cannot find ONNX model at {onnx_model_path} or {alt_path}" + ) + + logger.info("[CPU Preprocess] Loading ONNX model from %s", onnx_model_path) + model = onnx.load(onnx_model_path) + graph = model.graph + + # Extract parameters from ONNX model + initializers = {init.name: init for init in graph.initializer} + + # Conv weights for patch embedding + if "patch_embed.proj.weight" in initializers: + weight_tensor = initializers["patch_embed.proj.weight"] + self.conv_weight = onnx.numpy_helper.to_array(weight_tensor) + logger.info( + "[CPU Preprocess] Loaded conv weight: %s", self.conv_weight.shape + ) + else: + raise ValueError("patch_embed.proj.weight not found in ONNX model") + + # Gather indices for window reordering + if "blocks.window_index" in initializers: + indices_tensor = initializers["blocks.window_index"] + self.window_index = onnx.numpy_helper.to_array(indices_tensor) + logger.info( + "[CPU Preprocess] Loaded window_index: %s", self.window_index.shape + ) + else: + raise ValueError("blocks.window_index not found in ONNX model") + + # Reverse index for final postprocessing + if "merger.reverse_index" in initializers: + reverse_tensor = initializers["merger.reverse_index"] + self.reverse_index = onnx.numpy_helper.to_array(reverse_tensor) + logger.info( + "[CPU Preprocess] Loaded reverse_index: %s", self.reverse_index.shape + ) + else: + raise ValueError("merger.reverse_index not found in ONNX model") + + logger.info("[CPU Preprocess] Initialized successfully") + + def preprocess(self, pixel_values: torch.Tensor) -> np.ndarray: + """ + Apply CPU preprocessing operations to pixel_values. + + Args: + pixel_values: [seq_len, feature_dim] float32 tensor from HF processor + Expected shape: [4292, 1176] + + Returns: + preprocessed: [1073, 4, 1280] float32 numpy array ready for NPU + """ + # Convert to numpy + if isinstance(pixel_values, torch.Tensor): + pixel_values_np = pixel_values.cpu().float().numpy() + else: + pixel_values_np = pixel_values.astype(np.float32) + + logger.info("[CPU Preprocess] Input shape: %s", pixel_values_np.shape) + + # Operation 1: Reshape to [batch, 3, 2, 14, 14] + # pixel_values [4292, 1176] → [4292, 3, 2, 14, 14] + x = pixel_values_np.reshape(-1, 3, 2, 14, 14) + + # Operation 2: Conv3D for patch embedding + # Input: [4292, 3, 2, 14, 14] + # Weight: [1280, 3, 2, 14, 14] + # Output: [4292, 1280, 1, 1, 1] + out_channels = self.conv_weight.shape[0] + batch_size = x.shape[0] + conv_out = np.zeros((batch_size, out_channels, 1, 1, 1), dtype=np.float32) + + # Naive implementation - can be optimized with torch.nn.functional.conv3d + for b in range(batch_size): + for oc in range(out_channels): + conv_out[b, oc, 0, 0, 0] = np.sum(x[b] * self.conv_weight[oc]) + + # Operation 3: Reshape to [4292, 1280] + x2 = conv_out.reshape(-1, 1280) + + # Operation 4: Reshape to [1073, 4, 1280] - merge patches 4x4 + x3 = x2.reshape(1073, 4, 1280) + + # Operation 5: Gather with window_index (reordering) + # Note: This maintains shape [1073, 4, 1280] + x4 = x3[self.window_index] + + logger.info("[CPU Preprocess] Output shape: %s", x4.shape) + return x4 + + def postprocess(self, npu_output: np.ndarray) -> np.ndarray: + """ + Apply CPU postprocessing to NPU output. + + Args: + npu_output: [1073, 3584] float32 array from NPU + + Returns: + final_output: [1073, 3584] float32 array after reverse_index reordering + """ + # Apply final Gather with reverse_index + reordered = npu_output[self.reverse_index] + logger.info( + "[CPU Postprocess] Applied reverse_index, shape: %s", reordered.shape + ) + return reordered + + +class Qwen2_5_VL_CPUPreprocessor_Optimized: + """Optimized version using torch for Conv3D.""" + + def __init__(self, model_cache_dir: str): + """Initialize with torch-based Conv3D for faster preprocessing.""" + import os + + import onnx + + onnx_model_path = os.path.join( + os.path.dirname(os.path.dirname(model_cache_dir)), + "qwen2_5_vl_vision_stitched_7b.onnx", + ) + + if not os.path.exists(onnx_model_path): + logger.warning( + "[CPU Preprocess Optimized] ONNX not found at %s, trying alternative", + onnx_model_path, + ) + alt_path = os.path.join( + os.path.dirname(model_cache_dir), "qwen2_5_vl_vision_stitched_7b.onnx" + ) + if os.path.exists(alt_path): + onnx_model_path = alt_path + else: + raise FileNotFoundError( + f"Cannot find ONNX model at {onnx_model_path} or {alt_path}" + ) + + logger.info( + "[CPU Preprocess Optimized] Loading ONNX model from %s", onnx_model_path + ) + model = onnx.load(onnx_model_path) + graph = model.graph + initializers = {init.name: init for init in graph.initializer} + + # Load parameters and convert to torch + weight_np = onnx.numpy_helper.to_array(initializers["patch_embed.proj.weight"]) + self.conv_weight = torch.from_numpy(weight_np).float() + + self.window_index = onnx.numpy_helper.to_array( + initializers["blocks.window_index"] + ) + self.reverse_index = onnx.numpy_helper.to_array( + initializers["merger.reverse_index"] + ) + + # Release ONNX model from memory (saves ~600 MB CPU RAM) + del model, graph, initializers, weight_np + import gc + + gc.collect() + logger.info( + "[CPU Preprocess Optimized] Initialized with torch Conv3D " + "(ONNX model released from memory)" + ) + + def preprocess(self, pixel_values: torch.Tensor) -> np.ndarray: + """Optimized preprocessing using torch.nn.functional.conv3d.""" + pixel_values = pixel_values.cpu().float() + + # Reshape to [batch, 3, 2, 14, 14] + x = pixel_values.reshape(-1, 3, 2, 14, 14) + + # Conv3D using torch (much faster than numpy) + import torch.nn.functional as F + + # Rearrange to [batch, channels, depth, height, width] + conv_out = F.conv3d( + x, self.conv_weight, bias=None, stride=(2, 14, 14), padding=(0, 0, 0) + ) # Output: [4292, 1280, 1, 1, 1] + + # Reshape to [4292, 1280] + x2 = conv_out.reshape(-1, 1280) + + # Reshape to [1073, 4, 1280] + x3 = x2.reshape(1073, 4, 1280) + + # Gather with window_index + x4_np = x3.numpy()[self.window_index] + + logger.info("[CPU Preprocess Optimized] Output shape: %s", x4_np.shape) + return x4_np + + def postprocess(self, npu_output: np.ndarray) -> np.ndarray: + """Apply reverse_index reordering.""" + return npu_output[self.reverse_index] + + +# Factory function to get appropriate preprocessor +def get_cpu_preprocessor(model_cache_dir: str, optimized: bool = True): + """ + Get CPU preprocessor for Qwen2.5-VL vision model. + + Args: + model_cache_dir: Path to NPU model cache + optimized: Use torch-based optimized version (default: True) + + Returns: + Preprocessor instance + """ + if optimized: + try: + return Qwen2_5_VL_CPUPreprocessor_Optimized(model_cache_dir) + except Exception as e: + logger.warning( + "Failed to load optimized preprocessor: %s, falling back to numpy", + e, + ) + return Qwen2_5_VL_CPUPreprocessor(model_cache_dir) + else: + return Qwen2_5_VL_CPUPreprocessor(model_cache_dir) diff --git a/vllm/vision_npu/flexmlrt_backend.py b/vllm/vision_npu/flexmlrt_backend.py new file mode 100644 index 000000000000..9e01d65af206 --- /dev/null +++ b/vllm/vision_npu/flexmlrt_backend.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +FlexMLRT-based vision NPU backend with CPU preprocessing. + +VitisAI-compiled models partition operations between CPU and NPU. This backend +implements the CPU preprocessing operations before calling FlexMLRT for NPU +execution, matching the behavior of VitisAI ExecutionProvider. +""" + +import contextlib +import logging +import time + +import numpy as np +import torch + +import vllm.envs as envs + +from .backend import NPUVisionBackend +from .cpu_preprocess import get_cpu_preprocessor + +logger = logging.getLogger(__name__) + +# Cache environment variables for performance (avoids repeated lookups) +VLLM_NPU_TIMING = envs.VLLM_NPU_TIMING + + +@contextlib.contextmanager +def npu_timing(operation: str, logger_obj=None): + """Zero-overhead timing for NPU operations when VLLM_NPU_TIMING=1. + + Args: + operation: Name of the operation being timed + logger_obj: Optional logger to use (defaults to module logger) + """ + if not VLLM_NPU_TIMING: + yield + return + + start = time.monotonic() + try: + yield + finally: + elapsed_ms = (time.monotonic() - start) * 1000 + log_func = logger_obj.info if logger_obj else logger.info + log_func("[NPU Timing] %s: %.2fms", operation, elapsed_ms) + + +class FlexMLRTVisionBackend(NPUVisionBackend): + """FlexMLRT implementation of NPU vision backend with CPU preprocessing. + + Uses AMD FlexMLRT library to run vision models on Ryzen AI NPU. + Implements CPU preprocessing operations that VitisAI EP normally handles. + """ + + def __init__(self, model_cache_path: str, device_name: str = "stx"): + """Initialize FlexMLRT vision model with CPU preprocessing. + + Args: + model_cache_path: Path to VAIP model cache (vaiml_par_0 directory) + device_name: XRT device name ("stx" for Strix, "phx" for Phoenix) + """ + from vllm.vision_npu._vision_flexmlrt_cpu import VisionFlexMLRTModel + + self.model = VisionFlexMLRTModel(model_cache_path, device_name) + + # Initialize CPU preprocessor + self.preprocessor = get_cpu_preprocessor(model_cache_path, optimized=True) + logger.info("[FlexMLRT Backend] Initialized with CPU preprocessing") + + def forward(self, pixel_values: np.ndarray, grid_thw: np.ndarray) -> np.ndarray: + """Run vision encoding with CPU preprocessing + NPU execution. + + Pipeline: + 1. CPU preprocessing: [4292, 1176] → [1073, 4, 1280] + 2. NPU execution: [1073, 4, 1280] → [1073, 3584] + 3. CPU postprocessing: Apply reverse_index reordering + + Args: + pixel_values: [seq_len, feature_dim] float32 array from HF processor + grid_thw: [num_images, 3] int64 array (unused for now) + + Returns: + embeddings: [merged_seq_len, hidden_dim] float32 array + """ + total_start = time.monotonic() if VLLM_NPU_TIMING else None + + # Convert numpy to torch for preprocessing + with npu_timing("NumPy→Torch conversion", logger): + if isinstance(pixel_values, np.ndarray): + pixel_values_torch = torch.from_numpy(pixel_values).float() + else: + pixel_values_torch = pixel_values.float() + + # Step 1: CPU preprocessing + logger.debug( + "[FlexMLRT Backend] Preprocessing input shape: %s", pixel_values.shape + ) + with npu_timing("CPU preprocessing (total)", logger): + preprocessed = self.preprocessor.preprocess(pixel_values_torch) + + # Step 2: NPU execution + logger.debug( + "[FlexMLRT Backend] Running NPU inference on shape: %s", + preprocessed.shape, + ) + with npu_timing("NPU inference", logger): + npu_output = self.model.forward(preprocessed) + + # Step 3: CPU postprocessing + logger.debug( + "[FlexMLRT Backend] Postprocessing NPU output shape: %s", npu_output.shape + ) + with npu_timing("CPU postprocessing", logger): + final_output = self.preprocessor.postprocess(npu_output) + + logger.debug("[FlexMLRT Backend] Final output shape: %s", final_output.shape) + + # Log total time and memory stats + if VLLM_NPU_TIMING and total_start is not None: + total_ms = (time.monotonic() - total_start) * 1000 + logger.info("[NPU Timing] Total vision pipeline: %.2fms", total_ms) + logger.info("[NPU Memory] Input: %.2f MB", pixel_values.nbytes / 1024**2) + logger.info( + "[NPU Memory] Preprocessed: %.2f MB", preprocessed.nbytes / 1024**2 + ) + logger.info("[NPU Memory] Output: %.2f MB", final_output.nbytes / 1024**2) + logger.info( + "[ViT Output] Shape: %s \u2192 %d patches \u00d7 %d embedding_dim", + final_output.shape, + final_output.shape[0], + final_output.shape[1], + ) + + return final_output + + @property + def output_dim(self) -> int: + """Get output embedding dimension from FlexMLRT model.""" + return self.model.output_dim() From 988be0bcdb3e6fee9b1fe6ac63c0ad5ca690307a Mon Sep 17 00:00:00 2001 From: lichang Date: Mon, 1 Jun 2026 15:42:34 -0600 Subject: [PATCH 2/3] feat: Add NPU+GPU async vision pipelining for v1 engine Layer async NPU vision pre-encoding on top of the FlexMLRT backend: vision scheduler in EngineCore, scheduler deferral when vision is not ready, and gpu_model_runner pre-encoding thread pool. Gated by VLLM_NPU_ASYNC_PIPELINE=1 (default off). Co-authored-by: Cursor --- vllm/envs.py | 3 + vllm/model_executor/models/vision.py | 11 +- vllm/v1/core/sched/scheduler.py | 40 ++ vllm/v1/engine/core.py | 95 +++- vllm/v1/engine/output_processor.py | 30 ++ vllm/v1/executor/uniproc_executor.py | 11 + vllm/v1/worker/gpu_model_runner.py | 711 ++++++++++++++++++++++++++- vllm/vision_npu/flexmlrt_backend.py | 200 ++++++++ 8 files changed, 1084 insertions(+), 17 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index e5079e6438c4..9d62308df601 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -215,6 +215,7 @@ VLLM_VISION_NPU_BACKEND: str = "" VLLM_VISION_NPU_CACHE: str | None = None VLLM_VISION_NPU_DEVICE: str | None = None + VLLM_NPU_ASYNC_PIPELINE: bool = False VLLM_NPU_TIMING: bool = False VLLM_MORIIO_QP_PER_TRANSFER: int = 1 VLLM_MORIIO_POST_BATCH_SIZE: int = -1 @@ -1754,6 +1755,8 @@ def _get_or_set_default() -> str: "VLLM_VISION_NPU_CACHE": lambda: os.getenv("VLLM_VISION_NPU_CACHE"), # NPU device name (e.g., "stx" for Strix, "phx" for Phoenix) "VLLM_VISION_NPU_DEVICE": lambda: os.getenv("VLLM_VISION_NPU_DEVICE"), + # Enable async pipelining of NPU vision encoding with GPU LLM inference + "VLLM_NPU_ASYNC_PIPELINE": lambda: os.getenv("VLLM_NPU_ASYNC_PIPELINE", "0") == "1", # Enable NPU timing debug logs "VLLM_NPU_TIMING": lambda: os.getenv("VLLM_NPU_TIMING", "0") == "1", # Enable CUDA compatibility mode for datacenter GPUs with older diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index c01f291fb07b..0ab01e8d35a2 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -628,6 +628,7 @@ def get_npu_vision_backend(): Returns: NPUVisionBackend instance if NPU backend is enabled, None otherwise. + Returns AsyncFlexMLRTVisionBackend if VLLM_NPU_ASYNC_PIPELINE=1. Raises: ValueError: If backend name is recognized but initialization fails. @@ -647,8 +648,14 @@ def get_npu_vision_backend(): ) device_name = envs.VLLM_VISION_NPU_DEVICE or "stx" - from vllm.vision_npu.flexmlrt_backend import FlexMLRTVisionBackend + # Use async backend if pipelining is enabled + if envs.VLLM_NPU_ASYNC_PIPELINE: + from vllm.vision_npu.flexmlrt_backend import AsyncFlexMLRTVisionBackend - return FlexMLRTVisionBackend(model_cache, device_name) + return AsyncFlexMLRTVisionBackend(model_cache, device_name) + else: + from vllm.vision_npu.flexmlrt_backend import FlexMLRTVisionBackend + + return FlexMLRTVisionBackend(model_cache, device_name) return None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40b5899f0457..6fcc79834a71 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -103,6 +103,12 @@ def __init__( # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.enable_hybrid_pipeline = ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ) + # Set during schedule() when a request is deferred for NPU vision. + self.waiting_on_vision_encoding = False self.max_num_scheduled_tokens = ( self.scheduler_config.max_num_scheduled_tokens if self.scheduler_config.max_num_scheduled_tokens @@ -357,6 +363,8 @@ def schedule(self) -> SchedulerOutput: # chunked prefills, prefix caching, speculative decoding, # and the "jump decoding" optimization in the future. + self.waiting_on_vision_encoding = False + scheduled_new_reqs: list[Request] = [] scheduled_resumed_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = [] @@ -574,6 +582,25 @@ def schedule(self) -> SchedulerOutput: request = request_queue.peek_request() request_id = request.request_id + # HYBRID PIPELINING: defer prefill until NPU vision is ready. + if self.enable_hybrid_pipeline and self.max_num_running_reqs == 1: + needs_vision = ( + request.num_computed_tokens == 0 and request.mm_features + ) + if needs_vision: + from vllm.v1.engine.core import ( + _VISION_PREENCODING_CACHE, + is_vision_preencoding_ready, + ) + + if not is_vision_preencoding_ready( + request_id, _VISION_PREENCODING_CACHE + ): + self.waiting_on_vision_encoding = True + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) + continue + # try to promote blocked statuses while traversing skipped queue. if self._is_blocked_waiting_status( request.status @@ -814,6 +841,19 @@ def schedule(self) -> SchedulerOutput: continue self.running.append(request) + + if self.enable_hybrid_pipeline and self.max_num_running_reqs == 1: + is_vision_phase = ( + request.num_computed_tokens == 0 and request.mm_features + ) + phase_name = "VISION" if is_vision_phase else "LLM" + logger.debug( + "[Hybrid Scheduler] Scheduled %s in %s phase (running: %d)", + request.request_id, + phase_name, + len(self.running), + ) + if self.log_stats: request.record_event( EngineCoreEventType.SCHEDULED, scheduled_timestamp diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 6bf6910cc6f2..611395f60843 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -85,6 +85,34 @@ _R = TypeVar("_R") # Return type for collective_rpc +# Global vision pre-encoding cache (shared between EngineCore and workers) +_VISION_PREENCODING_CACHE: dict[str, Any] = {} + +# Busy-loop backoff while deferred requests wait on NPU vision (seconds). +_VISION_POLL_SLEEP_S = 0.02 +_DEFAULT_BUSY_LOOP_SLEEP_S = 0.001 + + +def is_vision_preencoding_ready( + request_id: str, cache: dict[str, Any] | None = None +) -> bool: + """Return True when background vision encoding finished for a request.""" + if cache is None: + cache = _VISION_PREENCODING_CACHE + cached = cache.get(request_id) + if cached == "COMPLETED": + return True + if cached is None: + return False + done = getattr(cached, "done", None) + return callable(done) and done() + + +def _request_has_vision_mm(request: Any) -> bool: + if not request.mm_features: + return False + return any(f.modality in ("image", "video") for f in request.mm_features) + class EngineCore: """Inner loop of vLLM's Engine.""" @@ -398,6 +426,59 @@ def log_iteration_details(self, scheduler_output: SchedulerOutput): ) self._iteration_index += 1 + def _schedule_waiting_vision(self) -> None: + """Vision Scheduler: Proactively trigger pre-encoding for waiting requests. + + This is the key to enabling pipelining with max-num-seqs=1: + - Core scheduler only schedules 1 LLM at a time (max-num-seqs=1) + - Vision scheduler processes ALL waiting requests' vision independently + - Request 2's vision can process while Request 1's LLM runs + """ + if not ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ): + return + + try: + waiting_requests = list(self.scheduler.waiting) # type: ignore[attr-defined] + except Exception as e: + logger.exception("[Vision Scheduler] Error accessing waiting queue: %s", e) + return + + if not waiting_requests: + return + + # Skip the scan when every waiting vision request is already submitted. + pending_submit = False + for request in waiting_requests: + if not _request_has_vision_mm(request): + continue + req_id = request.request_id + if req_id not in _VISION_PREENCODING_CACHE: + pending_submit = True + break + + if not pending_submit: + return + + for request in waiting_requests: + if not _request_has_vision_mm(request): + continue + + req_id = request.request_id + if req_id in _VISION_PREENCODING_CACHE: + continue + + has_mm_hash = any(mm_feature.mm_hash for mm_feature in request.mm_features) + if not has_mm_hash: + continue + + logger.debug( + "[Vision Scheduler] Submitting pre-encoding for request %s", req_id + ) + self.model_executor.submit_vision_encoding(req_id, request.mm_features) + def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: """Schedule, execute, and make output. @@ -409,6 +490,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return {}, False + + # Vision pre-encoding for waiting requests (overlaps with running LLM). + self._schedule_waiting_vision() + scheduler_output = self.scheduler.schedule() future = self.model_executor.execute_model(scheduler_output, non_block=True) grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) @@ -467,7 +552,12 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): + # VISION SCHEDULER: Proactively trigger pre-encoding + # Request 2's vision can start while Request 1's LLM runs + self._schedule_waiting_vision() + scheduler_output = self.scheduler.schedule() + with self.log_error_detail(scheduler_output): exec_future = self.model_executor.execute_model( scheduler_output, non_block=True @@ -1211,7 +1301,10 @@ def _process_engine_step(self) -> bool: # background threads (like NIXL handshake) to make progress. # Without this, the tight polling loop can starve background threads. if not model_executed and self.scheduler.has_unfinished_requests(): - time.sleep(0.001) + if getattr(self.scheduler, "waiting_on_vision_encoding", False): + time.sleep(_VISION_POLL_SLEEP_S) + else: + time.sleep(_DEFAULT_BUSY_LOOP_SLEEP_S) return model_executed diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1ae89ae19680..7ab5df1cc65d 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -10,6 +10,7 @@ import numpy as np import torch +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import ( STREAM_FINISHED, @@ -38,6 +39,8 @@ SchedulerStats, ) +logger = init_logger(__name__) + # shared empty CPU tensor used as a placeholder pooling output EMPTY_CPU_TENSOR = torch.empty(0, device="cpu") @@ -678,6 +681,33 @@ def process_outputs( self._update_stats_from_finished( req_state, finish_reason, iteration_stats ) + + # Debug logging for request timing + if req_state.stats and iteration_stats: + metrics = req_state.stats + e2e_time = ( + iteration_stats.iteration_timestamp - metrics.arrival_time + ) + queued_time = metrics.scheduled_ts - metrics.queued_ts + prefill_time = metrics.first_token_ts - metrics.scheduled_ts + decode_time = metrics.last_token_ts - metrics.first_token_ts + num_tokens = metrics.num_generation_tokens + tokens_per_sec = ( + num_tokens / decode_time if decode_time > 0 else 0 + ) + logger.debug( + "Request %s: E2E=%.3fs, Queue=%.3fs, " + "Prefill=%.3fs, Decode=%.3fs, " + "Tokens=%d (%.1f tok/s)", + req_state.request_id, + e2e_time, + queued_time, + prefill_time, + decode_time, + num_tokens, + tokens_per_sec, + ) + if self.tracing_enabled: self.do_tracing(engine_core_output, req_state, iteration_stats) diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index b616c3b7b8ad..1d646fb8b976 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -132,6 +132,17 @@ def check_health(self) -> None: # it's running. return + def submit_vision_encoding(self, req_id, mm_features): + """Submit vision encoding for a waiting request to enable pipelining. + + This is called by the Vision Scheduler to proactively start vision processing + for requests that are waiting in the queue (not yet scheduled for LLM). + """ + # Direct call to model_runner for UniProcExecutor (no RPC needed) + if hasattr(self.driver_worker, "model_runner"): + self.driver_worker.model_runner.submit_vision_encoding(req_id, mm_features) + return None + def shutdown(self) -> None: if worker := self.driver_worker: worker.shutdown() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d4c5ed273b74..582a5b370d82 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -218,6 +218,14 @@ logger = init_logger(__name__) + +# Custom exception for hybrid NPU+GPU pipelining +class VisionNotReadyError(Exception): + """Raised when vision encoding is not ready in hybrid pipelining mode.""" + + pass + + AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict @@ -624,6 +632,7 @@ def __init__( ) self._init_block_sizes = [placeholder_block_size] self._init_kernel_block_sizes = [placeholder_block_size] + self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoder @@ -2735,6 +2744,316 @@ def _batch_mm_inputs_from_scheduler( return mm_hashes, mm_kwargs, mm_lora_refs + def submit_vision_encoding(self, req_id, mm_features) -> None: + """Submit vision encoding for a waiting request (called by Vision Scheduler). + + This enables pipelining by starting vision processing for requests that are + waiting in the queue (not yet scheduled for LLM execution). + """ + enable_preencoding = ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ) + + if not enable_preencoding: + return + + # Initialize thread pool if not already done + if not hasattr(self, "_vision_preencoding_pool"): + from concurrent.futures import ThreadPoolExecutor + + self._vision_preencoding_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="vision_preenc" + ) + logger.info_once( + "[Model Runner] Vision pre-encoding thread pool initialized" + ) + + from vllm.v1.engine.core import _VISION_PREENCODING_CACHE + + # Check if already in cache + if req_id in _VISION_PREENCODING_CACHE: + return + + # Check if request has vision features + if not mm_features: + return + + logger.debug( + "[Model Runner] Submitting vision pre-encoding for request %s", req_id + ) + + # Create encoder_input_ids (use placeholder [0] for waiting requests) + encoder_input_ids = [0] + + # Submit to background thread + future = self._vision_preencoding_pool.submit( + self._encode_single_request_vision, req_id, encoder_input_ids, mm_features + ) + + _VISION_PREENCODING_CACHE[req_id] = future + + def _start_vision_preencoding(self, scheduler_output: "SchedulerOutput"): + """ + Start vision encoding in background thread for new encoder inputs. + + This is called at the beginning of execute_model() to start NPU vision + encoding ASAP, allowing it to overlap with previous request's LLM processing. + """ + enable_preencoding = ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ) + + if not enable_preencoding: + return + + # Initialize thread pool if not already done + if not hasattr(self, "_vision_preencoding_pool"): + from concurrent.futures import ThreadPoolExecutor + + self._vision_preencoding_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="vision_preenc" + ) + logger.info_once( + "[NPU Pre-encoding] Vision pre-encoding thread pool initialized" + ) + + from datetime import datetime + + from vllm.v1.engine.core import _VISION_PREENCODING_CACHE + + # Build a mapping of req_id -> mm_features from scheduled_new_reqs + req_id_to_mm_features = {} + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.mm_features: + req_id_to_mm_features[new_req.req_id] = new_req.mm_features + + # Start encoding for scheduled requests + if scheduler_output.scheduled_encoder_inputs: + scheduled_req_ids = list(scheduler_output.scheduled_encoder_inputs.keys()) + logger.debug( + "[NPU Pre-encoding] Scheduled encoder inputs for: %s", + scheduled_req_ids, + ) + # scheduled_encoder_inputs is a dict: {req_id: [encoder_input_ids]} + for ( + req_id, + encoder_input_ids, + ) in scheduler_output.scheduled_encoder_inputs.items(): + # Skip if already encoding OR already encoded + if req_id in _VISION_PREENCODING_CACHE: + cached_value = _VISION_PREENCODING_CACHE[req_id] + if cached_value == "COMPLETED": + # Already encoded, skip + logger.debug( + "[NPU Pre-encoding] SCHEDULED Request %s: SKIP - completed", + req_id, + ) + continue + else: + # Future object - already in progress, skip + logger.debug( + "[NPU Pre-encoding] SCHEDULED Request %s: SKIP - in prog", + req_id, + ) + continue + + # Get mm_features for this request from scheduled_new_reqs + if req_id not in req_id_to_mm_features: + # This shouldn't happen, but skip if no mm_features found + logger.warning( + "[NPU Pre-encoding] Request %s: No mm_features found", + req_id, + ) + continue + + mm_features = req_id_to_mm_features[req_id] + + logger.debug( + "[NPU Pre-encoding] SCHEDULED Request %s: Submitting at %s", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + ) + + # Submit to background thread + # Pass encoder_input_ids and mm_features directly + future = self._vision_preencoding_pool.submit( + self._encode_single_request_vision, + req_id, + encoder_input_ids, + mm_features, + ) + + _VISION_PREENCODING_CACHE[req_id] = future + + # Start encoding for IMMEDIATE requests (highest priority) + immediate_keys = [ + k for k in _VISION_PREENCODING_CACHE if k.startswith("immediate_") + ] + for immediate_key in immediate_keys: + request = _VISION_PREENCODING_CACHE[immediate_key] + req_id = immediate_key.replace("immediate_", "") + + # Check if not already processing or completed + if req_id in _VISION_PREENCODING_CACHE: + cached_value = _VISION_PREENCODING_CACHE[req_id] + if cached_value == "COMPLETED": + logger.debug( + "[NPU Pre-encoding] IMMEDIATE Request %s: SKIPPING - completed", + req_id, + ) + del _VISION_PREENCODING_CACHE[immediate_key] + continue + # Already started (Future object), remove the immediate marker + logger.debug( + "[NPU Pre-encoding] IMMEDIATE Request %s: SKIPPING - in progress", + req_id, + ) + del _VISION_PREENCODING_CACHE[immediate_key] + continue + + # Extract encoder_input_ids + encoder_input_ids = list(range(len(request.mm_features))) + + logger.debug( + "[NPU Pre-encoding] IMMEDIATE Request %s: Submitting at %s", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + ) + + # Submit to background thread + future = self._vision_preencoding_pool.submit( + self._encode_single_request_vision, + req_id, + encoder_input_ids, + request.mm_features, + ) + + _VISION_PREENCODING_CACHE[req_id] = future + # Clean up the immediate marker + del _VISION_PREENCODING_CACHE[immediate_key] + + # Start encoding for WAITING requests (for pipelining) + # Check if EngineCore marked any waiting requests + waiting_keys = [ + k for k in _VISION_PREENCODING_CACHE if k.startswith("waiting_") + ] + for waiting_key in waiting_keys: + request = _VISION_PREENCODING_CACHE[waiting_key] + req_id = waiting_key.replace("waiting_", "") + + # Check if not already processing or completed + if req_id in _VISION_PREENCODING_CACHE: + cached_value = _VISION_PREENCODING_CACHE[req_id] + if cached_value == "COMPLETED": + logger.debug( + "[NPU Pre-encoding] WAITING Request %s: SKIPPING - completed", + req_id, + ) + del _VISION_PREENCODING_CACHE[waiting_key] + continue + # Already in progress (Future object) + logger.debug( + "[NPU Pre-encoding] WAITING Request %s: SKIPPING - in progress", + req_id, + ) + del _VISION_PREENCODING_CACHE[waiting_key] + continue + + # Extract encoder_input_ids (typically [0] for first vision input) + encoder_input_ids = list(range(len(request.mm_features))) + + logger.debug( + "[NPU Pre-encoding] WAITING Request %s: Submitting at %s", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + ) + + # Submit to background thread + future = self._vision_preencoding_pool.submit( + self._encode_single_request_vision, + req_id, + encoder_input_ids, + request.mm_features, + ) + + _VISION_PREENCODING_CACHE[req_id] = future + # Clean up the waiting marker + del _VISION_PREENCODING_CACHE[waiting_key] + + def _encode_single_request_vision( + self, req_id: str, encoder_input_ids: list[int], mm_features: list + ): + """Encode vision for a single request in background thread. + + Args: + req_id: Request ID + encoder_input_ids: List of multimodal input indices for this request + mm_features: List of MultiModalFeatureSpec from scheduler_output + + Returns: + List of encoded vision embeddings (one per encoder_input_id) + """ + from datetime import datetime + + logger.debug( + "[NPU Pre-encoding] Request %s: Vision encoding STARTED at %s", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + ) + + try: + # Extract mm_kwargs for the encoder_input_ids + # This mirrors what _batch_mm_inputs_from_scheduler() does + mm_kwargs_list = [] + for mm_input_id in encoder_input_ids: + mm_feature = mm_features[mm_input_id] + if mm_feature.data is not None: + mm_kwargs_list.append((mm_feature.modality, mm_feature.data)) + + if not mm_kwargs_list: + logger.warning( + "[NPU Pre-encoding] Request %s: No valid multimodal data found", + req_id, + ) + return [] + + # Batch the multimodal kwargs using the same logic as _execute_mm_encoder + from vllm.multimodal.utils import group_and_batch_mm_kwargs + + batches = list( + group_and_batch_mm_kwargs( + mm_kwargs_list, + device=self.device, + pin_memory=self.pin_memory, + ) + ) + + # Use the model's vision encoder + model = cast(SupportsMultiModal, self.model) + + # Process each batch (typically just one batch for single request) + encoder_outputs: list[torch.Tensor] = [] + for modality, num_items, mm_kwargs_batch in batches: + # Call embed_multimodal with properly batched inputs + batch_outputs = model.embed_multimodal(**mm_kwargs_batch) + encoder_outputs.extend(batch_outputs) + + logger.debug( + "[NPU Pre-encoding] Request %s: Vision COMPLETED at %s (%d embeddings)", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + len(encoder_outputs), + ) + + return encoder_outputs + except Exception as e: + logger.exception( + "[NPU Pre-encoding] Request %s: Vision encoding FAILED: %s", req_id, e + ) + raise + def _execute_mm_encoder( self, scheduler_output: "SchedulerOutput" ) -> list[torch.Tensor]: @@ -2745,6 +3064,184 @@ def _execute_mm_encoder( if not mm_kwargs: return [] + # Log which requests are being encoded + from datetime import datetime + + req_ids = ( + list(scheduler_output.scheduled_encoder_inputs.keys()) + if scheduler_output.scheduled_encoder_inputs + else [] + ) + logger.debug( + "[NPU Pre-encoding] _execute_mm_encoder called at %s for requests: %s", + datetime.now().strftime("%H:%M:%S.%f")[:-3], + req_ids, + ) + + # Check for pre-encoded vision embeddings from background thread + enable_preencoding = ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ) + + if enable_preencoding: + # Import global cache + from vllm.v1.engine.core import _VISION_PREENCODING_CACHE + + # Try to get pre-encoded embeddings for this batch + # scheduled_encoder_inputs is a dict: {req_id: [encoder_input_ids]} + req_ids = list(scheduler_output.scheduled_encoder_inputs.keys()) + + # Split requests into pre-encoded and non-pre-encoded + preencoded_req_ids = [] + non_preencoded_req_ids = [] + + for req_id in req_ids: + if req_id in _VISION_PREENCODING_CACHE: + cached_value = _VISION_PREENCODING_CACHE[req_id] + if cached_value == "COMPLETED": + # Already completed - this shouldn't happen but log it + logger.warning( + "[NPU Pre-encoding] Request %s: COMPLETED but re-encoding", + req_id, + ) + non_preencoded_req_ids.append(req_id) + else: + # Future object - pre-encoding in progress or completed + preencoded_req_ids.append(req_id) + else: + non_preencoded_req_ids.append(req_id) + + # Process pre-encoded requests first + preencoded_outputs = [] + preencoded_hashes = [] + + # Check which requests have vision ready + ready_req_ids = [] + not_ready_req_ids = [] + + for req_id in preencoded_req_ids: + future = _VISION_PREENCODING_CACHE[req_id] + + from datetime import datetime + + # Check if vision encoding is complete (non-blocking) + if not future.done(): + # Vision still in progress + logger.debug( + "[NPU Pre-encoding] Request %s: Vision NOT ready yet at %s", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + ) + not_ready_req_ids.append(req_id) + else: + ready_req_ids.append(req_id) + + # Process ready requests + for req_id in ready_req_ids: + future = _VISION_PREENCODING_CACHE[req_id] + + from datetime import datetime + + logger.debug( + "[NPU Pre-encoding] Request %s: Vision ready, getting result at %s", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + ) + + # Vision is complete, get result (won't block) + embeddings_list = future.result() + + logger.debug( + "[NPU Pre-encoding] Request %s: Got vision at %s (%d embeddings)", + req_id, + datetime.now().strftime("%H:%M:%S.%f")[:-3], + len(embeddings_list), + ) + + # Find corresponding mm_hash for this request + req_idx = req_ids.index(req_id) + preencoded_hashes.append(mm_hashes[req_idx]) + preencoded_outputs.extend(embeddings_list) + + # Replace the Future with completion marker (prevents re-encoding) + _VISION_PREENCODING_CACHE[req_id] = "COMPLETED" + logger.debug( + "[NPU Pre-encoding] Request %s: Marked as COMPLETED in cache", + req_id, + ) + + # For not-ready requests in hybrid pipelining mode: mark in cache + # Allows the scheduler to move on while NPU processes this request + if not_ready_req_ids: + enable_hybrid_pipeline = ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ) + if enable_hybrid_pipeline: + logger.debug( + "[NPU Pre-encoding] %d requests NOT ready, marking: %s", + len(not_ready_req_ids), + not_ready_req_ids, + ) + # Mark these requests' mm_hashes as NOT_READY in encoder cache + # This prevents _gather_mm_embeddings from asserting + for req_id in not_ready_req_ids: + req_idx = req_ids.index(req_id) + self.encoder_cache[mm_hashes[req_idx]] = "NOT_READY" + # Don't add to non_preencoded - skip synchronous encoding + else: + # Legacy behavior: block until vision completes + logger.warning( + "[NPU Pre-encoding] %d requests not ready, will BLOCK: %s", + len(not_ready_req_ids), + not_ready_req_ids, + ) + non_preencoded_req_ids.extend(not_ready_req_ids) + + # Cache pre-encoded embeddings + if preencoded_outputs: + logger.debug( + "[NPU Pre-encoding] Using pre-encoded: %d embeddings, %d reqs", + len(preencoded_outputs), + len(preencoded_req_ids), + ) + for mm_hash, output in zip(preencoded_hashes, preencoded_outputs): + self.encoder_cache[mm_hash] = output + logger.debug("Using pre-encoded vision for mm hash %s", mm_hash) + self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + + # If ALL requests were pre-encoded, return early + if not non_preencoded_req_ids: + return preencoded_outputs + + # Otherwise, we need to process non-pre-encoded requests below + # Filter mm_kwargs and mm_hashes to only include non-pre-encoded requests + logger.debug( + "[NPU Pre-encoding] %d requests need synchronous encoding", + len(non_preencoded_req_ids), + ) + + # Rebuild mm_kwargs and mm_hashes for only non-pre-encoded requests + non_preencoded_mm_kwargs = [] + non_preencoded_mm_hashes = [] + non_preencoded_mm_lora_refs = [] + + for req_id in non_preencoded_req_ids: + req_idx = req_ids.index(req_id) + non_preencoded_mm_kwargs.append(mm_kwargs[req_idx]) + non_preencoded_mm_hashes.append(mm_hashes[req_idx]) + # mm_lora_refs needs to be filtered carefully + # It's a list of tuples (req_id, pos_info), filter by req_id + for lora_ref in mm_lora_refs: + if lora_ref[0] == req_id: + non_preencoded_mm_lora_refs.append(lora_ref) + + # Update variables for the rest of the function + mm_kwargs = non_preencoded_mm_kwargs + mm_hashes = non_preencoded_mm_hashes + mm_lora_refs = non_preencoded_mm_lora_refs + should_time = bool( self.observability_config and self.observability_config.enable_mm_processor_stats @@ -2829,14 +3326,135 @@ def _execute_mm_encoder( connector_mapping, ) + # Check if NPU backend and parallel processing enabled + using_npu = envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + enable_parallel = using_npu and envs.VLLM_NPU_ASYNC_PIPELINE + + # Collect all batches from the generator first + batches = list( + group_and_batch_mm_kwargs( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ) + ) + + # DEBUG: Log batch count + if envs.VLLM_NPU_TIMING: + import threading + from datetime import datetime + + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.info( + "[GPU Model Runner] %s T-%d: npu=%s, par=%s, batches=%d, kwargs=%d", + timestamp, + threading.get_ident(), + using_npu, + enable_parallel, + len(batches), + len(mm_kwargs), + ) + encoder_outputs: list[torch.Tensor] = [] + + # NPU parallel processing path + if enable_parallel and len(batches) > 1: + from concurrent.futures import ThreadPoolExecutor + + if envs.VLLM_NPU_TIMING: + import threading + from datetime import datetime + + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.info( + "[GPU Model Runner] %s Thread-%d: %d batches PARALLEL", + timestamp, + threading.get_ident(), + len(batches), + ) + + def process_batch_item(batch_info): + modality, num_items, mm_kwargs_batch = batch_info + if envs.VLLM_NPU_TIMING: + import threading + from datetime import datetime + + start = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.info( + "[GPU Model Runner] %s Thread-%d: Starting %s (items=%d)", + start, + threading.get_ident(), + modality, + num_items, + ) + + # Call embed_multimodal for this batch + batch_outputs = model.embed_multimodal(**mm_kwargs_batch) + sanity_check_mm_encoder_outputs( + batch_outputs, expected_num_items=num_items + ) + + if envs.VLLM_NPU_TIMING: + import threading + from datetime import datetime + + end = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.info( + "[GPU Model Runner] %s Thread-%d: Finished batch for %s", + end, + threading.get_ident(), + modality, + ) + + return batch_outputs + + # Process all batches in parallel + with ThreadPoolExecutor( + max_workers=len(batches), thread_name_prefix="vision_parallel" + ) as executor: + futures = [ + executor.submit(process_batch_item, batch) for batch in batches + ] + batch_results = [f.result() for f in futures] + + # Flatten results + for outputs in batch_results: + encoder_outputs.extend(outputs) + + # Cache the encoder outputs by mm_hash + for mm_hash, output in zip(mm_hashes, encoder_outputs): + self.encoder_cache[mm_hash] = output + logger.debug("Finish execute for mm hash %s", mm_hash) + self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + + if envs.VLLM_NPU_TIMING: + from datetime import datetime + + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.info( + "[GPU Model Runner] %s: All %d batches completed in parallel", + timestamp, + len(batches), + ) + + return encoder_outputs + + # Standard sequential path (GPU or single batch or parallel disabled) + if envs.VLLM_NPU_TIMING and batches: + import threading + from datetime import datetime + + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + logger.info( + "[GPU Model Runner] %s Thread-%d: Processing %d batches SEQUENTIALLY", + timestamp, + threading.get_ident(), + len(batches), + ) + # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs current_item_idx = 0 - for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): + for modality, num_items, mm_kwargs_batch in batches: batch_outputs: MultiModalEmbeddings # EVS and dynamic res video related change. @@ -2917,6 +3535,37 @@ def _execute_mm_encoder( logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) + # Combine pre-encoded outputs (if any) with synchronously encoded outputs + if ( + enable_preencoding + and "preencoded_outputs" in locals() + and preencoded_outputs + ): + # Merge the two lists in correct order based on original req_ids + all_req_ids = list(scheduler_output.scheduled_encoder_inputs.keys()) + combined_outputs = [] + + preencoded_idx = 0 + sync_idx = 0 + + for req_id in all_req_ids: + if req_id in preencoded_req_ids: + # This request was pre-encoded + combined_outputs.append(preencoded_outputs[preencoded_idx]) + preencoded_idx += 1 + else: + # This request was synchronously encoded + combined_outputs.append(encoder_outputs[sync_idx]) + sync_idx += 1 + + logger.debug( + "[NPU Pre-encoding] Combined %d pre-encoded + %d sync = %d total", + len(preencoded_outputs), + len(encoder_outputs), + len(combined_outputs), + ) + return combined_outputs + return encoder_outputs def _gather_mm_embeddings( @@ -2975,6 +3624,19 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) + + # Check if vision encoding is still in progress (hybrid pipelining mode) + if encoder_output == "NOT_READY": + enable_hybrid_pipeline = ( + envs.VLLM_NPU_ASYNC_PIPELINE + and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + ) + if enable_hybrid_pipeline: + # Raise exception to signal execute_model should skip request + raise VisionNotReadyError( + f"Vision not ready for req {req_id}, mm_hash {mm_hash}" + ) + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." if (is_embed := pos_info.is_embed) is not None: @@ -3236,12 +3898,21 @@ def _preprocess( if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder: # Run the multimodal encoder if any. - with self.maybe_get_ec_connector_output( - scheduler_output, - encoder_cache=self.encoder_cache, - ) as ec_connector_output: - self._execute_mm_encoder(scheduler_output) - mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) + try: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings( + scheduler_output + ) + except VisionNotReadyError as e: + # Vision not ready - return None to signal scheduler to skip step + logger.debug( + "[Hybrid Pipelining] Vision not ready: %s - returning None", e + ) + return None, None, None, None, {}, None # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -3830,6 +4501,9 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: IntermediateTensors | None = None, ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: + # Start vision pre-encoding in background for any new encoder inputs + self._start_vision_preencoding(scheduler_output) + if self.execute_model_state is not None: raise RuntimeError( "State error: sample_tokens() must be called " @@ -4040,6 +4714,17 @@ def execute_model( ) ) + preprocess_result = self._preprocess( + scheduler_output, num_tokens_padded, intermediate_tensors + ) + + # Check if vision encoding was not ready (hybrid pipelining mode) + if preprocess_result == (None, None, None, None, {}, None): + logger.info( + "[Hybrid Pipelining] Vision not ready, returning EMPTY_OUTPUT" + ) + return EMPTY_MODEL_RUNNER_OUTPUT + ( input_ids, inputs_embeds, @@ -4047,9 +4732,7 @@ def execute_model( intermediate_tensors, model_kwargs, ec_connector_output, - ) = self._preprocess( - scheduler_output, num_tokens_padded, intermediate_tensors - ) + ) = preprocess_result # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible diff --git a/vllm/vision_npu/flexmlrt_backend.py b/vllm/vision_npu/flexmlrt_backend.py index 9e01d65af206..9cf3e6d910ea 100644 --- a/vllm/vision_npu/flexmlrt_backend.py +++ b/vllm/vision_npu/flexmlrt_backend.py @@ -9,9 +9,11 @@ execution, matching the behavior of VitisAI ExecutionProvider. """ +import asyncio import contextlib import logging import time +from concurrent.futures import ThreadPoolExecutor import numpy as np import torch @@ -25,6 +27,7 @@ # Cache environment variables for performance (avoids repeated lookups) VLLM_NPU_TIMING = envs.VLLM_NPU_TIMING +VLLM_NPU_ASYNC_PIPELINE = envs.VLLM_NPU_ASYNC_PIPELINE @contextlib.contextmanager @@ -140,3 +143,200 @@ def forward(self, pixel_values: np.ndarray, grid_thw: np.ndarray) -> np.ndarray: def output_dim(self) -> int: """Get output embedding dimension from FlexMLRT model.""" return self.model.output_dim() + + +class AsyncFlexMLRTVisionBackend: + """Async wrapper for FlexMLRT backend enabling NPU+GPU pipelining. + + Allows NPU vision processing for request N+1 to overlap with GPU LLM + processing for request N, improving throughput for multi-request workloads. + + Example throughput improvement: + - Sequential: Request1(NPU 13.5s + GPU 20s) → Request2(NPU 13.5s + GPU + 20s) = 67s for 2 requests + - Pipelined: Request1(NPU 13.5s) → overlap(NPU 13.5s for Req2 || GPU 20s + for Req1) → GPU 20s for Req2 = 47s for 2 requests + - Speedup: 1.43x for 2 requests, approaches 1.5x+ for longer sequences + """ + + def __init__(self, model_cache_path: str, device_name: str = "stx"): + """Initialize async wrapper with underlying synchronous backend. + + Args: + model_cache_path: Path to VAIP model cache (vaiml_par_0 directory) + device_name: XRT device name ("stx" for Strix, "phx" for Phoenix) + """ + # Underlying synchronous backend + self.sync_backend = FlexMLRTVisionBackend(model_cache_path, device_name) + + # Thread pool for NPU inference (separate from GPU thread) + # Single worker ensures NPU executes one request at a time + self.npu_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="npu_vision" + ) + + # Stats for monitoring + self.npu_queue_size = 0 + self.total_requests = 0 + + if VLLM_NPU_ASYNC_PIPELINE: + logger.info( + "[Async FlexMLRT Backend] Initialized with async pipelining enabled" + ) + else: + logger.info( + "[Async FlexMLRT Backend] Initialized " + "(async disabled, use VLLM_NPU_ASYNC_PIPELINE=1)" + ) + + async def forward_async( + self, pixel_values: np.ndarray, grid_thw: np.ndarray + ) -> np.ndarray: + """Async version that enables NPU-GPU pipelining. + + Submits NPU work to a dedicated executor, allowing it to run concurrently + with GPU work from other requests. + + Args: + pixel_values: [seq_len, feature_dim] float32 array from HF processor + grid_thw: [num_images, 3] int64 array + + Returns: + embeddings: [merged_seq_len, hidden_dim] float32 array + """ + loop = asyncio.get_event_loop() + + self.npu_queue_size += 1 + self.total_requests += 1 + request_id = self.total_requests + + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU] Request %s submitted to NPU queue (queue size: %s)", + request_id, + self.npu_queue_size, + ) + + try: + # Submit to NPU executor (non-blocking from caller's perspective) + # This allows GPU to continue processing previous requests while NPU works + result = await loop.run_in_executor( + self.npu_executor, self.sync_backend.forward, pixel_values, grid_thw + ) + + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU] Request %s completed NPU processing", request_id + ) + + return result + finally: + self.npu_queue_size -= 1 + + def forward(self, pixel_values: np.ndarray, grid_thw: np.ndarray) -> np.ndarray: + """Synchronous interface with async execution underneath. + + Submits work to NPU executor thread, allowing multiple requests to pipeline. + This blocks the caller until NPU processing completes, but allows other + threads (e.g., GPU LLM processing) to run concurrently. + """ + import threading + from datetime import datetime + + self.npu_queue_size += 1 + self.total_requests += 1 + request_id = self.total_requests + + submit_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] + caller_thread = threading.get_ident() + + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU Pipeline] Request %s SUBMITTED at %s by Thread-%s " + "(queue size: %s)", + request_id, + submit_time, + caller_thread, + self.npu_queue_size, + ) + + try: + # Submit to executor - allows pipelining with GPU work from + # other requests + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU Pipeline] Request %s submitting to ThreadPoolExecutor " + "(queue size before: %s)", + request_id, + self.npu_queue_size, + ) + + future = self.npu_executor.submit( + self._forward_with_timing, pixel_values, grid_thw, request_id + ) + + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU Pipeline] Request %s future created, " + "now waiting for result...", + request_id, + ) + + # Block until NPU processing completes + result = future.result() + + complete_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU Pipeline] Request %s COMPLETED at %s on Thread-%s", + request_id, + complete_time, + caller_thread, + ) + + return result + finally: + self.npu_queue_size -= 1 + + def _forward_with_timing( + self, pixel_values: np.ndarray, grid_thw: np.ndarray, request_id: int + ) -> np.ndarray: + """Internal forward with NPU start/end timing.""" + import threading + from datetime import datetime + + worker_thread = threading.get_ident() + npu_start_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] + + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU Pipeline] Request %s NPU STARTED at %s on " + "NPU-Worker-Thread-%s", + request_id, + npu_start_time, + worker_thread, + ) + + result = self.sync_backend.forward(pixel_values, grid_thw) + + npu_end_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] + if VLLM_NPU_TIMING: + logger.info( + "[Async NPU Pipeline] Request %s NPU FINISHED at %s on " + "NPU-Worker-Thread-%s", + request_id, + npu_end_time, + worker_thread, + ) + + return result + + @property + def output_dim(self) -> int: + """Get output embedding dimension from FlexMLRT model.""" + return self.sync_backend.output_dim + + def __del__(self): + """Cleanup thread pool on deletion.""" + if hasattr(self, "npu_executor"): + self.npu_executor.shutdown(wait=True) From 8d16afac9f0caf2f41d4dbc7e1ca8a243e0e5c14 Mon Sep 17 00:00:00 2001 From: lichang Date: Mon, 1 Jun 2026 16:08:30 -0600 Subject: [PATCH 3/3] refactor: simplify NPU vision env and CPU preprocessor layout Use VLLM_VISION_NPU_CACHE as the sole enable switch instead of VLLM_VISION_NPU_BACKEND. Move Qwen2.5-VL CPU preprocessing into vision_npu/models/ and drop the unoptimized numpy path. Keep AsyncFlexMLRTVisionBackend for pipelining when VLLM_NPU_ASYNC_PIPELINE=1. Co-authored-by: Cursor --- vllm/envs.py | 5 +- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/model_executor/models/vision.py | 47 +-- vllm/multimodal/utils.py | 5 +- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/engine/core.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 12 +- vllm/vision_npu/cpu_preprocess.py | 276 +----------------- vllm/vision_npu/flexmlrt_backend.py | 2 +- vllm/vision_npu/models/__init__.py | 4 + .../models/qwen2_5_vl_cpu_preprocess.py | 93 ++++++ 11 files changed, 136 insertions(+), 314 deletions(-) create mode 100644 vllm/vision_npu/models/__init__.py create mode 100644 vllm/vision_npu/models/qwen2_5_vl_cpu_preprocess.py diff --git a/vllm/envs.py b/vllm/envs.py index 9d62308df601..2558c73497bb 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -212,7 +212,6 @@ VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False - VLLM_VISION_NPU_BACKEND: str = "" VLLM_VISION_NPU_CACHE: str | None = None VLLM_VISION_NPU_DEVICE: str | None = None VLLM_NPU_ASYNC_PIPELINE: bool = False @@ -1749,9 +1748,7 @@ def _get_or_set_default() -> str: # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes # Triton compilation to fail. "VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))), - # NPU vision backend to use (e.g., "flexmlrt" for FlexMLRT backend) - "VLLM_VISION_NPU_BACKEND": lambda: os.getenv("VLLM_VISION_NPU_BACKEND", ""), - # Path to NPU model cache directory (required for FlexMLRT backend) + # Path to NPU model cache directory (enables FlexMLRT vision backend when set) "VLLM_VISION_NPU_CACHE": lambda: os.getenv("VLLM_VISION_NPU_CACHE"), # NPU device name (e.g., "stx" for Strix, "phx" for Phoenix) "VLLM_VISION_NPU_DEVICE": lambda: os.getenv("VLLM_VISION_NPU_DEVICE"), diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index df54948c43fa..9bca8e8a5f1a 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -601,7 +601,7 @@ def __init__( logger.error("[Qwen2.5VL] NPU backend init failed: %s", e) raise RuntimeError( f"NPU vision backend initialization failed: {e}. " - "Set VLLM_VISION_NPU_BACKEND='' to use PyTorch backend." + "Unset VLLM_VISION_NPU_CACHE to use PyTorch backend." ) from e self.npu_backend = None diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 0ab01e8d35a2..4d7e528fffce 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -609,53 +609,36 @@ def get_llm_pos_ids_for_vision( def use_npu_vision_backend() -> bool: - """Check if NPU backend is enabled for vision processing. - - Returns: - True if VLLM_VISION_NPU_BACKEND environment variable is set to - a supported backend (flexmlrt), False otherwise. - """ + """Check if NPU vision is enabled via VLLM_VISION_NPU_CACHE.""" import vllm.envs as envs - backend = ( - envs.VLLM_VISION_NPU_BACKEND.lower() if envs.VLLM_VISION_NPU_BACKEND else "" - ) - return backend == "flexmlrt" + return bool(envs.VLLM_VISION_NPU_CACHE) def get_npu_vision_backend(): - """Get NPU vision backend instance if enabled. + """Get FlexMLRT NPU vision backend instance if VLLM_VISION_NPU_CACHE is set. Returns: - NPUVisionBackend instance if NPU backend is enabled, None otherwise. - Returns AsyncFlexMLRTVisionBackend if VLLM_NPU_ASYNC_PIPELINE=1. + FlexMLRTVisionBackend, or AsyncFlexMLRTVisionBackend when + VLLM_NPU_ASYNC_PIPELINE=1. Raises: - ValueError: If backend name is recognized but initialization fails. + ValueError: If VLLM_VISION_NPU_CACHE is set but initialization fails. ImportError: If backend dependencies are not available. """ import vllm.envs as envs - backend_name = ( - envs.VLLM_VISION_NPU_BACKEND.lower() if envs.VLLM_VISION_NPU_BACKEND else "" - ) + model_cache = envs.VLLM_VISION_NPU_CACHE + if not model_cache: + return None - if backend_name == "flexmlrt": - model_cache = envs.VLLM_VISION_NPU_CACHE - if not model_cache: - raise ValueError( - "VLLM_VISION_NPU_CACHE must be set when using FlexMLRT backend" - ) - device_name = envs.VLLM_VISION_NPU_DEVICE or "stx" + device_name = envs.VLLM_VISION_NPU_DEVICE or "stx" - # Use async backend if pipelining is enabled - if envs.VLLM_NPU_ASYNC_PIPELINE: - from vllm.vision_npu.flexmlrt_backend import AsyncFlexMLRTVisionBackend + if envs.VLLM_NPU_ASYNC_PIPELINE: + from vllm.vision_npu.flexmlrt_backend import AsyncFlexMLRTVisionBackend - return AsyncFlexMLRTVisionBackend(model_cache, device_name) - else: - from vllm.vision_npu.flexmlrt_backend import FlexMLRTVisionBackend + return AsyncFlexMLRTVisionBackend(model_cache, device_name) - return FlexMLRTVisionBackend(model_cache, device_name) + from vllm.vision_npu.flexmlrt_backend import FlexMLRTVisionBackend - return None + return FlexMLRTVisionBackend(model_cache, device_name) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 0f7146eba3aa..b5ccf3f081fe 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -39,10 +39,7 @@ def _is_npu_vision_backend() -> bool: """ import vllm.envs as envs - backend = ( - envs.VLLM_VISION_NPU_BACKEND.lower() if envs.VLLM_VISION_NPU_BACKEND else "" - ) - return backend == "flexmlrt" + return bool(envs.VLLM_VISION_NPU_CACHE) def encode_audio_base64( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 6fcc79834a71..9b4cceaccc7d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -105,7 +105,7 @@ def __init__( self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.enable_hybrid_pipeline = ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ) # Set during schedule() when a request is deferred for NPU vision. self.waiting_on_vision_encoding = False diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 611395f60843..19a51abd6fdb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -436,7 +436,7 @@ def _schedule_waiting_vision(self) -> None: """ if not ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ): return diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 582a5b370d82..0cacf22e0174 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2752,7 +2752,7 @@ def submit_vision_encoding(self, req_id, mm_features) -> None: """ enable_preencoding = ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ) if not enable_preencoding: @@ -2802,7 +2802,7 @@ def _start_vision_preencoding(self, scheduler_output: "SchedulerOutput"): """ enable_preencoding = ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ) if not enable_preencoding: @@ -3081,7 +3081,7 @@ def _execute_mm_encoder( # Check for pre-encoded vision embeddings from background thread enable_preencoding = ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ) if enable_preencoding: @@ -3176,7 +3176,7 @@ def _execute_mm_encoder( if not_ready_req_ids: enable_hybrid_pipeline = ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ) if enable_hybrid_pipeline: logger.debug( @@ -3327,7 +3327,7 @@ def _execute_mm_encoder( ) # Check if NPU backend and parallel processing enabled - using_npu = envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + using_npu = bool(envs.VLLM_VISION_NPU_CACHE) enable_parallel = using_npu and envs.VLLM_NPU_ASYNC_PIPELINE # Collect all batches from the generator first @@ -3629,7 +3629,7 @@ def _gather_mm_embeddings( if encoder_output == "NOT_READY": enable_hybrid_pipeline = ( envs.VLLM_NPU_ASYNC_PIPELINE - and envs.VLLM_VISION_NPU_BACKEND.lower() == "flexmlrt" + and envs.VLLM_VISION_NPU_CACHE ) if enable_hybrid_pipeline: # Raise exception to signal execute_model should skip request diff --git a/vllm/vision_npu/cpu_preprocess.py b/vllm/vision_npu/cpu_preprocess.py index ba6e9d3afbcc..6bc3066dd868 100644 --- a/vllm/vision_npu/cpu_preprocess.py +++ b/vllm/vision_npu/cpu_preprocess.py @@ -1,279 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -CPU preprocessing operations for VitisAI-compiled vision models. - -This module implements the CPU operations that VitisAI ExecutionProvider -normally handles automatically. When using FlexMLRT directly, we must -manually implement these operations. - -For Qwen2.5-VL vision model: -- Input: pixel_values [4292, 1176] from HuggingFace processor -- Output: preprocessed [1073, 4, 1280] ready for NPU -- Postprocessing: Apply reverse_index Gather to NPU output +Factory for model-specific CPU preprocessors used by NPU vision backends. """ -import logging +from __future__ import annotations + +from typing import Protocol import numpy as np import torch -logger = logging.getLogger(__name__) - - -class Qwen2_5_VL_CPUPreprocessor: - """CPU preprocessing for Qwen2.5-VL vision model before NPU execution.""" - - def __init__(self, model_cache_dir: str): - """ - Initialize CPU preprocessor with required parameters. - - Args: - model_cache_dir: Path to NPU model cache directory containing ONNX model - """ - import os - - import onnx - - # Load ONNX model to extract parameters - # model_cache_dir is typically: .../qwen2_5_vl_vision_stitched_7b/vaiml_par_0 - # We need to go up two levels to find the .onnx file - onnx_model_path = os.path.join( - os.path.dirname(os.path.dirname(model_cache_dir)), - "qwen2_5_vl_vision_stitched_7b.onnx", - ) - - if not os.path.exists(onnx_model_path): - logger.warning( - "[CPU Preprocess] ONNX not found at %s, trying alternative path", - onnx_model_path, - ) - # Alternative: look in parent directory - alt_path = os.path.join( - os.path.dirname(model_cache_dir), "qwen2_5_vl_vision_stitched_7b.onnx" - ) - if os.path.exists(alt_path): - onnx_model_path = alt_path - else: - raise FileNotFoundError( - f"Cannot find ONNX model at {onnx_model_path} or {alt_path}" - ) - - logger.info("[CPU Preprocess] Loading ONNX model from %s", onnx_model_path) - model = onnx.load(onnx_model_path) - graph = model.graph - - # Extract parameters from ONNX model - initializers = {init.name: init for init in graph.initializer} - - # Conv weights for patch embedding - if "patch_embed.proj.weight" in initializers: - weight_tensor = initializers["patch_embed.proj.weight"] - self.conv_weight = onnx.numpy_helper.to_array(weight_tensor) - logger.info( - "[CPU Preprocess] Loaded conv weight: %s", self.conv_weight.shape - ) - else: - raise ValueError("patch_embed.proj.weight not found in ONNX model") - - # Gather indices for window reordering - if "blocks.window_index" in initializers: - indices_tensor = initializers["blocks.window_index"] - self.window_index = onnx.numpy_helper.to_array(indices_tensor) - logger.info( - "[CPU Preprocess] Loaded window_index: %s", self.window_index.shape - ) - else: - raise ValueError("blocks.window_index not found in ONNX model") - - # Reverse index for final postprocessing - if "merger.reverse_index" in initializers: - reverse_tensor = initializers["merger.reverse_index"] - self.reverse_index = onnx.numpy_helper.to_array(reverse_tensor) - logger.info( - "[CPU Preprocess] Loaded reverse_index: %s", self.reverse_index.shape - ) - else: - raise ValueError("merger.reverse_index not found in ONNX model") - - logger.info("[CPU Preprocess] Initialized successfully") - - def preprocess(self, pixel_values: torch.Tensor) -> np.ndarray: - """ - Apply CPU preprocessing operations to pixel_values. - - Args: - pixel_values: [seq_len, feature_dim] float32 tensor from HF processor - Expected shape: [4292, 1176] - - Returns: - preprocessed: [1073, 4, 1280] float32 numpy array ready for NPU - """ - # Convert to numpy - if isinstance(pixel_values, torch.Tensor): - pixel_values_np = pixel_values.cpu().float().numpy() - else: - pixel_values_np = pixel_values.astype(np.float32) - - logger.info("[CPU Preprocess] Input shape: %s", pixel_values_np.shape) - - # Operation 1: Reshape to [batch, 3, 2, 14, 14] - # pixel_values [4292, 1176] → [4292, 3, 2, 14, 14] - x = pixel_values_np.reshape(-1, 3, 2, 14, 14) - - # Operation 2: Conv3D for patch embedding - # Input: [4292, 3, 2, 14, 14] - # Weight: [1280, 3, 2, 14, 14] - # Output: [4292, 1280, 1, 1, 1] - out_channels = self.conv_weight.shape[0] - batch_size = x.shape[0] - conv_out = np.zeros((batch_size, out_channels, 1, 1, 1), dtype=np.float32) - - # Naive implementation - can be optimized with torch.nn.functional.conv3d - for b in range(batch_size): - for oc in range(out_channels): - conv_out[b, oc, 0, 0, 0] = np.sum(x[b] * self.conv_weight[oc]) - - # Operation 3: Reshape to [4292, 1280] - x2 = conv_out.reshape(-1, 1280) - - # Operation 4: Reshape to [1073, 4, 1280] - merge patches 4x4 - x3 = x2.reshape(1073, 4, 1280) - - # Operation 5: Gather with window_index (reordering) - # Note: This maintains shape [1073, 4, 1280] - x4 = x3[self.window_index] - - logger.info("[CPU Preprocess] Output shape: %s", x4.shape) - return x4 - - def postprocess(self, npu_output: np.ndarray) -> np.ndarray: - """ - Apply CPU postprocessing to NPU output. - - Args: - npu_output: [1073, 3584] float32 array from NPU - - Returns: - final_output: [1073, 3584] float32 array after reverse_index reordering - """ - # Apply final Gather with reverse_index - reordered = npu_output[self.reverse_index] - logger.info( - "[CPU Postprocess] Applied reverse_index, shape: %s", reordered.shape - ) - return reordered - - -class Qwen2_5_VL_CPUPreprocessor_Optimized: - """Optimized version using torch for Conv3D.""" - - def __init__(self, model_cache_dir: str): - """Initialize with torch-based Conv3D for faster preprocessing.""" - import os - - import onnx - - onnx_model_path = os.path.join( - os.path.dirname(os.path.dirname(model_cache_dir)), - "qwen2_5_vl_vision_stitched_7b.onnx", - ) - - if not os.path.exists(onnx_model_path): - logger.warning( - "[CPU Preprocess Optimized] ONNX not found at %s, trying alternative", - onnx_model_path, - ) - alt_path = os.path.join( - os.path.dirname(model_cache_dir), "qwen2_5_vl_vision_stitched_7b.onnx" - ) - if os.path.exists(alt_path): - onnx_model_path = alt_path - else: - raise FileNotFoundError( - f"Cannot find ONNX model at {onnx_model_path} or {alt_path}" - ) - - logger.info( - "[CPU Preprocess Optimized] Loading ONNX model from %s", onnx_model_path - ) - model = onnx.load(onnx_model_path) - graph = model.graph - initializers = {init.name: init for init in graph.initializer} - - # Load parameters and convert to torch - weight_np = onnx.numpy_helper.to_array(initializers["patch_embed.proj.weight"]) - self.conv_weight = torch.from_numpy(weight_np).float() - - self.window_index = onnx.numpy_helper.to_array( - initializers["blocks.window_index"] - ) - self.reverse_index = onnx.numpy_helper.to_array( - initializers["merger.reverse_index"] - ) - - # Release ONNX model from memory (saves ~600 MB CPU RAM) - del model, graph, initializers, weight_np - import gc - - gc.collect() - logger.info( - "[CPU Preprocess Optimized] Initialized with torch Conv3D " - "(ONNX model released from memory)" - ) - - def preprocess(self, pixel_values: torch.Tensor) -> np.ndarray: - """Optimized preprocessing using torch.nn.functional.conv3d.""" - pixel_values = pixel_values.cpu().float() - - # Reshape to [batch, 3, 2, 14, 14] - x = pixel_values.reshape(-1, 3, 2, 14, 14) - - # Conv3D using torch (much faster than numpy) - import torch.nn.functional as F - - # Rearrange to [batch, channels, depth, height, width] - conv_out = F.conv3d( - x, self.conv_weight, bias=None, stride=(2, 14, 14), padding=(0, 0, 0) - ) # Output: [4292, 1280, 1, 1, 1] - - # Reshape to [4292, 1280] - x2 = conv_out.reshape(-1, 1280) - - # Reshape to [1073, 4, 1280] - x3 = x2.reshape(1073, 4, 1280) - - # Gather with window_index - x4_np = x3.numpy()[self.window_index] +from vllm.vision_npu.models.qwen2_5_vl_cpu_preprocess import Qwen2_5_VLCpuPreprocessor - logger.info("[CPU Preprocess Optimized] Output shape: %s", x4_np.shape) - return x4_np - def postprocess(self, npu_output: np.ndarray) -> np.ndarray: - """Apply reverse_index reordering.""" - return npu_output[self.reverse_index] +class CpuPreprocessor(Protocol): + """Interface for CPU-side vision preprocessing before NPU execution.""" + def preprocess(self, pixel_values: torch.Tensor) -> np.ndarray: ... -# Factory function to get appropriate preprocessor -def get_cpu_preprocessor(model_cache_dir: str, optimized: bool = True): - """ - Get CPU preprocessor for Qwen2.5-VL vision model. + def postprocess(self, npu_output: np.ndarray) -> np.ndarray: ... - Args: - model_cache_dir: Path to NPU model cache - optimized: Use torch-based optimized version (default: True) - Returns: - Preprocessor instance - """ - if optimized: - try: - return Qwen2_5_VL_CPUPreprocessor_Optimized(model_cache_dir) - except Exception as e: - logger.warning( - "Failed to load optimized preprocessor: %s, falling back to numpy", - e, - ) - return Qwen2_5_VL_CPUPreprocessor(model_cache_dir) - else: - return Qwen2_5_VL_CPUPreprocessor(model_cache_dir) +def get_cpu_preprocessor(model_cache_dir: str) -> CpuPreprocessor: + """Return the CPU preprocessor for the compiled model at model_cache_dir.""" + return Qwen2_5_VLCpuPreprocessor(model_cache_dir) diff --git a/vllm/vision_npu/flexmlrt_backend.py b/vllm/vision_npu/flexmlrt_backend.py index 9cf3e6d910ea..7f587a919f98 100644 --- a/vllm/vision_npu/flexmlrt_backend.py +++ b/vllm/vision_npu/flexmlrt_backend.py @@ -144,7 +144,6 @@ def output_dim(self) -> int: """Get output embedding dimension from FlexMLRT model.""" return self.model.output_dim() - class AsyncFlexMLRTVisionBackend: """Async wrapper for FlexMLRT backend enabling NPU+GPU pipelining. @@ -340,3 +339,4 @@ def __del__(self): """Cleanup thread pool on deletion.""" if hasattr(self, "npu_executor"): self.npu_executor.shutdown(wait=True) + diff --git a/vllm/vision_npu/models/__init__.py b/vllm/vision_npu/models/__init__.py new file mode 100644 index 000000000000..8f2db09550ed --- /dev/null +++ b/vllm/vision_npu/models/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Model-specific CPU preprocessors for NPU vision backends.""" diff --git a/vllm/vision_npu/models/qwen2_5_vl_cpu_preprocess.py b/vllm/vision_npu/models/qwen2_5_vl_cpu_preprocess.py new file mode 100644 index 000000000000..77b6c836a9b8 --- /dev/null +++ b/vllm/vision_npu/models/qwen2_5_vl_cpu_preprocess.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +CPU preprocessing for Qwen2.5-VL vision models compiled for FlexMLRT. + +VitisAI-compiled models partition operations between CPU and NPU. FlexMLRT +requires the CPU path to be implemented explicitly: + +- Input: pixel_values [4292, 1176] from the HuggingFace processor +- Output: preprocessed [1073, 4, 1280] ready for NPU +- Postprocessing: apply reverse_index gather to NPU output +""" + +from __future__ import annotations + +import gc +import logging +import os + +import numpy as np +import onnx +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + +_ONNX_FILENAME = "qwen2_5_vl_vision_stitched_7b.onnx" + + +def _resolve_onnx_model_path(model_cache_dir: str) -> str: + """Locate the stitched vision ONNX model next to the NPU cache directory.""" + # model_cache_dir is typically: .../qwen2_5_vl_vision_stitched_7b/vaiml_par_0 + candidates = ( + os.path.join( + os.path.dirname(os.path.dirname(model_cache_dir)), + _ONNX_FILENAME, + ), + os.path.join(os.path.dirname(model_cache_dir), _ONNX_FILENAME), + ) + for path in candidates: + if os.path.exists(path): + return path + raise FileNotFoundError( + f"Cannot find ONNX model {_ONNX_FILENAME} near {model_cache_dir}" + ) + + +class Qwen2_5_VLCpuPreprocessor: + """CPU preprocessing for Qwen2.5-VL vision models before NPU execution.""" + + def __init__(self, model_cache_dir: str): + onnx_model_path = _resolve_onnx_model_path(model_cache_dir) + logger.info("[Qwen2.5-VL CPU Preprocess] Loading ONNX model from %s", onnx_model_path) + + model = onnx.load(onnx_model_path) + graph = model.graph + initializers = {init.name: init for init in graph.initializer} + + weight_np = onnx.numpy_helper.to_array(initializers["patch_embed.proj.weight"]) + self.conv_weight = torch.from_numpy(weight_np).float() + self.window_index = onnx.numpy_helper.to_array( + initializers["blocks.window_index"] + ) + self.reverse_index = onnx.numpy_helper.to_array( + initializers["merger.reverse_index"] + ) + + del model, graph, initializers, weight_np + gc.collect() + logger.info( + "[Qwen2.5-VL CPU Preprocess] Initialized (ONNX model released from memory)" + ) + + def preprocess(self, pixel_values: torch.Tensor) -> np.ndarray: + """Apply CPU preprocessing to pixel_values before NPU execution.""" + pixel_values = pixel_values.cpu().float() + x = pixel_values.reshape(-1, 3, 2, 14, 14) + conv_out = F.conv3d( + x, + self.conv_weight, + bias=None, + stride=(2, 14, 14), + padding=(0, 0, 0), + ) + x2 = conv_out.reshape(-1, 1280) + x3 = x2.reshape(1073, 4, 1280) + x4_np = x3.numpy()[self.window_index] + logger.info("[Qwen2.5-VL CPU Preprocess] Output shape: %s", x4_np.shape) + return x4_np + + def postprocess(self, npu_output: np.ndarray) -> np.ndarray: + """Apply reverse_index reordering to NPU output.""" + return npu_output[self.reverse_index]