Skip to content

Commit 2b7db68

Browse files
committed
qwen2/3vl: memoize HF processor _merge_kwargs by call signature
ProcessorMixin._merge_kwargs (transformers) is pure but runs on every processor call. When all requests pass the same kwargs (the common deployment case), caching by signature converts a per-call merge into an O(1) lookup after the first call. Implemented as a wrapper installed on the processor instance at construction time, so it doesn't require any change to transformers. Cache key is the repr of sorted kwargs items; values are deep-copied on get and put because callers mutate the returned dict. Signed-off-by: Aswin Visva <31215515+aswinvisva@users.noreply.github.com>
1 parent 50acdb5 commit 2b7db68

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,41 @@
7777
PAD_INDEX = -100 # NOTE: refer to https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L269
7878

7979

80+
def _install_merge_kwargs_cache(processor) -> None:
81+
"""Memoize ``processor._merge_kwargs`` by input kwargs signature.
82+
83+
``ProcessorMixin._merge_kwargs`` is pure but runs on every processor
84+
call. When all requests pass the same kwargs (the common deployment
85+
case), caching by signature reduces it to an O(1) lookup after the
86+
first call.
87+
88+
Values are deep-copied on get and put because callers mutate the
89+
returned dict.
90+
"""
91+
import copy
92+
93+
if getattr(processor, "_merge_kwargs_cached_installed", False):
94+
return
95+
96+
cache: dict = {}
97+
orig = processor._merge_kwargs
98+
99+
def _cached_merge_kwargs(*args, **kwargs):
100+
try:
101+
key = repr(sorted(kwargs.items()))
102+
except Exception:
103+
return orig(*args, **kwargs)
104+
hit = cache.get(key)
105+
if hit is not None:
106+
return copy.deepcopy(hit)
107+
result = orig(*args, **kwargs)
108+
cache[key] = copy.deepcopy(result)
109+
return result
110+
111+
processor._merge_kwargs = _cached_merge_kwargs
112+
processor._merge_kwargs_cached_installed = True
113+
114+
80115
def _prepare_qwen_vl_vision_attn_metadata(
81116
seq_lens: List[int],
82117
attn_metadata: AttentionMetadata) -> AttentionMetadata:
@@ -186,6 +221,7 @@ def __init__(self,
186221
model_path,
187222
use_fast=self.use_fast,
188223
trust_remote_code=trust_remote_code)
224+
_install_merge_kwargs_cache(self._processor)
189225

190226
self.tllm_multimodal_token_id = self.get_vocab_size() + 1
191227
# temporal patch size for video frames

0 commit comments

Comments
 (0)