Skip to content

Commit 1dea0dd

Browse files
yechank-nvidia authored and 2ez4bz committed
[None][fix] Support Qwen VL image embedding inputs
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent b036594 commit 1dea0dd

3 files changed

Lines changed: 88 additions & 101 deletions

File tree

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 77 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -476,6 +476,83 @@ def get_mrope_config(
476476
'cpu').to(torch.int32).clone()
477477
return mrope_config
478478

479+
@staticmethod
480+
def _infer_image_grid_thw(num_tokens: int,
481+
spatial_merge_size: int) -> List[int]:
482+
if num_tokens <= 0:
483+
raise ValueError(
484+
f"Image embedding must contain at least one token, got {num_tokens}"
485+
)
486+
llm_grid_h = int(num_tokens**0.5)
487+
while llm_grid_h > 1 and num_tokens % llm_grid_h != 0:
488+
llm_grid_h -= 1
489+
llm_grid_w = num_tokens // llm_grid_h
490+
return [
491+
1,
492+
llm_grid_h * spatial_merge_size,
493+
llm_grid_w * spatial_merge_size,
494+
]
495+
496+
def _attach_multimodal_embeddings_impl(
497+
self,
498+
inputs: TextPrompt,
499+
multimodal_embedding: Dict[str, List[torch.Tensor]],
500+
sampling_params: SamplingParams,
501+
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
502+
if not isinstance(multimodal_embedding, dict):
503+
raise ValueError("multimodal_embedding must be a dictionary")
504+
if set(multimodal_embedding) != {"image"}:
505+
raise ValueError(
506+
"Only image modality is supported for external multimodal embedding"
507+
)
508+
509+
image_embeddings = multimodal_embedding["image"]
510+
if isinstance(image_embeddings, torch.Tensor):
511+
image_embeddings = [image_embeddings]
512+
if not image_embeddings:
513+
raise ValueError("At least one image embedding is required")
514+
for index, image_embedding in enumerate(image_embeddings):
515+
if image_embedding.dim() != 2:
516+
raise ValueError(
517+
f"Image embedding {index} must be rank 2, got shape {tuple(image_embedding.shape)}"
518+
)
519+
520+
get_prompt_token_ids = getattr(self, "get_prompt_token_ids", None)
521+
if not callable(get_prompt_token_ids):
522+
raise NotImplementedError(
523+
f"{type(self).__name__} does not support external multimodal embeddings"
524+
)
525+
526+
mm_handles = [{
527+
"tensor_size": tuple(image_embedding.shape)
528+
} for image_embedding in image_embeddings]
529+
prompt_token_ids, _, _ = get_prompt_token_ids(inputs, mm_handles)
530+
531+
mrope_input_ids = torch.tensor(prompt_token_ids,
532+
dtype=torch.long).unsqueeze(0)
533+
mrope_input_ids = mrope_input_ids.clone()
534+
multimodal_token_id = self.tllm_multimodal_token_id
535+
mrope_input_ids[mrope_input_ids ==
536+
multimodal_token_id] = self.config.image_token_id
537+
spatial_merge_size = self.config.vision_config.spatial_merge_size
538+
image_grid_thw = torch.tensor(
539+
[
540+
self._infer_image_grid_thw(image_embedding.shape[0],
541+
spatial_merge_size)
542+
for image_embedding in image_embeddings
543+
],
544+
dtype=torch.long,
545+
)
546+
attention_mask = torch.ones_like(mrope_input_ids)
547+
mrope_config = self.get_mrope_config(mrope_input_ids, image_grid_thw,
548+
None, attention_mask, None)
549+
550+
multimodal_data = {
551+
"multimodal_embedding": image_embeddings,
552+
"mrope_config": mrope_config,
553+
}
554+
return prompt_token_ids, {"multimodal_data": multimodal_data}
555+
479556
@nvtx_range("Qwen2VLInputProcessorBase forward()")
480557
@torch.inference_mode()
481558
def call_with_text_prompt(

tests/unittest/llmapi/apps/_attach_multimodal_embeddings_patch/__init__.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py

Lines changed: 11 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import io
22
import os
3-
import sys
43
import tempfile
54
from base64 import b64encode
65
from pathlib import Path
@@ -147,35 +146,14 @@ def temp_extra_encoder_options_file() -> str:
147146
return "/dummy/path"
148147

149148

150-
@pytest.fixture(scope="module")
151-
def server_patched(model_name: str, temp_extra_llm_api_options_file: str):
152-
# Custom module implements missing 'attach_multimodal_embeddings' to intercept
153-
# embeddings.
154-
model_path = get_model_path(model_name)
155-
args = [
156-
"--extra_llm_api_options",
157-
temp_extra_llm_api_options_file,
158-
"--max_batch_size",
159-
"64",
160-
"--max_num_tokens",
161-
"16384",
162-
"--custom_module_dirs",
163-
str(
164-
Path(sys.modules[test_single_chat_session_image_embeds.__module__].
165-
__file__).parent / "_attach_multimodal_embeddings_patch"),
166-
]
167-
with RemoteOpenAIServer(model_path, args) as remote_server:
168-
yield remote_server
169-
170-
171149
@pytest.mark.needs_l40s
172150
@pytest.mark.asyncio(loop_scope="module")
173151
def test_single_chat_session_image_embeds(
174-
server_patched: RemoteOpenAIServer,
152+
server: RemoteOpenAIServer,
175153
model_name: str,
176154
mm_encoder_server: RemoteMMEncoderServer,
177155
):
178-
client = server_patched.get_client()
156+
client = server.get_client()
179157
messages, mm_embed_handle = _test_multimodal_content_mm_encoder(
180158
mm_encoder_server.get_client(), model_name)
181159

@@ -201,30 +179,15 @@ def test_single_chat_session_image_embeds(
201179
"data": b64encode(mm_embed_bytes).decode("ascii")
202180
}
203181

204-
# test single completion
205-
#
206-
# FIXME: Remove try-except and use 'server' instead of 'server_patched',
207-
# once Qwen2VLInputProcessorBase implements attach_multimodal_embeddings.
208-
try:
209-
chat_completion_embeds = client.chat.completions.create(
210-
model=model_name,
211-
messages=messages,
212-
max_completion_tokens=max_completion_tokens,
213-
temperature=0.0,
214-
logprobs=False)
215-
216-
assert chat_completion_embeds.choices[
217-
0].message == chat_completion_image.choices[0].message
218-
except openai.BadRequestError as e:
219-
assert isinstance(e.body, dict)
220-
with open(Path(e.body["message"]), "rb") as f:
221-
intercepted_embeddings = torch.load(f, weights_only=True)
222-
assert list(intercepted_embeddings.keys()) == ["image"]
223-
assert len(intercepted_embeddings["image"]) == 1
224-
torch.testing.assert_close(intercepted_embeddings["image"][0],
225-
mm_embed.cpu())
226-
pytest.xfail(
227-
reason="Model does not implement 'attach_multimodal_embeddings'")
182+
chat_completion_embeds = client.chat.completions.create(
183+
model=model_name,
184+
messages=messages,
185+
max_completion_tokens=max_completion_tokens,
186+
temperature=0.0,
187+
logprobs=False)
188+
189+
assert chat_completion_embeds.choices[
190+
0].message == chat_completion_image.choices[0].message
228191

229192

230193
@pytest.mark.asyncio(loop_scope="module")

0 commit comments

Comments
 (0)