Skip to content

Commit bab97cd

Browse files
authored
Merge pull request #22 from vllm-project/v0.18.0-fix
V0.18.0 fix
2 parents (331e24a + 8d34ba4), commit bab97cd

2 files changed

Lines changed: 112 additions & 26 deletions

File tree

tests/test_vllm_018_compat.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,47 @@
1+
"""Regression tests for vLLM 0.18 compatibility in the BART processor."""
2+
3+
import torch
4+
5+
6+
def test_text_data_parser_handles_v018_empty_inputs():
7+
from vllm_bart_plugin.bart import TextDataParser
8+
9+
parser = TextDataParser()
10+
11+
assert parser._parse_text_data("") is None
12+
assert parser._parse_text_data([]) is None
13+
14+
15+
def test_create_encoder_prompt_uses_placeholder_token():
16+
from vllm_bart_plugin.bart import BartMultiModalProcessor
17+
18+
processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
19+
20+
assert processor.create_encoder_prompt("<s>decoder text", {"texts": ["encoder text"]}) == [0]
21+
22+
23+
def test_call_hf_processor_accepts_pretokenized_decoder_prompt():
24+
from vllm_bart_plugin.bart import BartMultiModalProcessor
25+
26+
class FakeTokenizer:
27+
def __call__(self, text, return_tensors="pt", **kwargs):
28+
if text == "encoder text":
29+
return {"input_ids": torch.tensor([[11, 12, 13]])}
30+
return {"input_ids": torch.tensor([[21, 22]])}
31+
32+
class FakeInfo:
33+
def get_tokenizer(self):
34+
return FakeTokenizer()
35+
36+
processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
37+
processor.info = FakeInfo()
38+
39+
out = processor._call_hf_processor(
40+
[7, 8, 9],
41+
{"texts": ["encoder text"]},
42+
{},
43+
{},
44+
)
45+
46+
assert torch.equal(out["encoder_input_ids"], torch.tensor([[11, 12, 13]]))
47+
assert torch.equal(out["input_ids"], torch.tensor([[7, 8, 9]]))

vllm_bart_plugin/bart.py

Lines changed: 65 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -945,6 +945,21 @@ def get_data_parser(self) -> MultiModalDataParser:
945945
return TextDataParser()
946946

947947

948+
# vLLM >=0.18 moved tokenization defaults from a global enc-dec override
949+
# (InputPreprocessor._get_tokenization_kw) into per-model ProcessingInfo.
950+
# The old code forced add_special_tokens=False for every is_encoder_decoder
951+
# model; replicate that here so the renderer does not inject extra BOS/EOS
952+
# into the decoder prompt. On vLLM <0.18 the method does not exist on the
953+
# base class and is not needed (the global override handles it).
954+
if hasattr(BaseProcessingInfo, "get_default_tok_params"):
955+
956+
def _bart_get_default_tok_params(self):
957+
return super(BartProcessingInfo, self).get_default_tok_params() \
958+
.with_kwargs(add_special_tokens=False)
959+
960+
BartProcessingInfo.get_default_tok_params = _bart_get_default_tok_params # type: ignore[attr-defined]
961+
962+
948963
class BartDummyInputsBuilder(BaseDummyInputsBuilder[BartProcessingInfo]):
949964
"""Builds dummy inputs for profiling BART models."""
950965

@@ -993,12 +1008,10 @@ def _parse_text_data(
9931008
data: ModalityData[str],
9941009
) -> ModalityDataItems[Any, Any] | None:
9951010
"""Parse text data for BART."""
996-
if data is None:
1011+
# _is_empty was removed in vLLM >=0.18; handle emptiness inline
1012+
if data is None or not len(data):
9971013
return TextProcessorItems(None)
9981014

999-
if self._is_empty(data):
1000-
return None
1001-
10021015
# Text data should be a string or list of strings
10031016
if isinstance(data, str) or is_list_of(data, str):
10041017
return TextProcessorItems(data)
@@ -1030,15 +1043,23 @@ def create_encoder_prompt(
10301043
prompt: str | list[int],
10311044
mm_data: MultiModalDataDict,
10321045
) -> str | list[int]:
1033-
if not prompt:
1034-
return [0]
1035-
tokenizer = self.info.get_tokenizer()
1036-
tokens = tokenizer(
1037-
prompt,
1038-
add_special_tokens=False,
1039-
return_tensors="pt",
1040-
)["input_ids"].flatten()
1041-
return tokens.tolist()
1046+
# vLLM compatibility:
1047+
# - Legacy (<0.18): prompt is encoder text (str) — tokenize directly.
1048+
# - Modern (>=0.18): prompt is decoder token IDs or empty str from
1049+
# profiling — return a single [0] placeholder that _get_prompt_updates
1050+
# will expand to the real encoder token count. The placeholder IDs
1051+
# are structural (KV-cache sizing); the actual encoder computation
1052+
# uses encoder_input_ids from mm_kwargs.
1053+
if isinstance(prompt, str) and prompt:
1054+
tokenizer = self.info.get_tokenizer()
1055+
tokens = tokenizer(
1056+
prompt,
1057+
add_special_tokens=False,
1058+
return_tensors="pt",
1059+
)["input_ids"].flatten()
1060+
return tokens.tolist()
1061+
1062+
return [0]
10421063

10431064
def create_decoder_prompt(
10441065
self,
@@ -1056,10 +1077,16 @@ def _call_hf_processor(
10561077
tok_kwargs: Mapping[str, object],
10571078
):
10581079
"""
1059-
BART doesn't have a HuggingFace Processor - it only has a tokenizer.
1060-
We tokenize both the prompt (decoder) and encoder text from mm_data.
1080+
BART doesn't have a HuggingFace Processor — it only has a tokenizer.
1081+
1082+
Produces two sets of token IDs:
1083+
- ``encoder_input_ids``: tokenized encoder text from ``mm_data["texts"]``
1084+
- ``input_ids``: tokenized decoder prompt (used by the base class to
1085+
build ``prompt_token_ids``)
1086+
1087+
Encoder text is always tokenized with ``add_special_tokens=False`` to
1088+
match v0.16 behaviour and stay consistent with ``_get_prompt_updates``.
10611089
"""
1062-
# tok_kwargs["add_special_tokens"] = False
10631090
from transformers.feature_extraction_utils import BatchFeature
10641091

10651092
tokenizer = self.info.get_tokenizer()
@@ -1069,24 +1096,29 @@ def _call_hf_processor(
10691096
result = {}
10701097

10711098
if has_encoder_data:
1072-
# Tokenize the encoder text from mm_data
10731099
encoder_texts = mm_data["texts"]
10741100
encoder_text = encoder_texts[0] if encoder_texts else ""
1101+
# Tokenize the encoder text from mm_data
10751102
encoder_tokenized = tokenizer(
10761103
encoder_text,
10771104
return_tensors="pt",
1078-
**tok_kwargs,
1105+
add_special_tokens=False,
10791106
)
10801107
result["encoder_input_ids"] = encoder_tokenized["input_ids"]
10811108

1082-
# Always tokenize the prompt (for decoder or as dummy)
1083-
# This will be popped by the base class
1084-
prompt_tokenized = tokenizer(
1085-
prompt if prompt else "",
1086-
return_tensors="pt",
1087-
**tok_kwargs,
1088-
)
1089-
result["input_ids"] = prompt_tokenized["input_ids"]
1109+
# Always produce input_ids for the decoder prompt.
1110+
# In vLLM >=0.18 the rendering pipeline may call _call_hf_processor
1111+
# with an already-tokenized prompt (a list of ints) instead of a str.
1112+
# Handle both cases.
1113+
if isinstance(prompt, (list, tuple)) and len(prompt) > 0 and isinstance(prompt[0], int):
1114+
result["input_ids"] = torch.tensor([prompt])
1115+
else:
1116+
prompt_tokenized = tokenizer(
1117+
prompt if prompt else "",
1118+
return_tensors="pt",
1119+
**tok_kwargs,
1120+
)
1121+
result["input_ids"] = prompt_tokenized["input_ids"]
10901122

10911123
return BatchFeature(result)
10921124

@@ -1105,6 +1137,13 @@ def _get_prompt_updates(
11051137
hf_processor_mm_kwargs: Mapping[str, object],
11061138
out_mm_kwargs: MultiModalKwargsItems,
11071139
) -> Sequence[PromptUpdate]:
1140+
"""Replace the single [0] encoder placeholder with N placeholder
1141+
tokens, where N equals the tokenized length of the encoder text.
1142+
1143+
The token count must use ``add_special_tokens=False`` to stay
1144+
consistent with ``_call_hf_processor`` (which tokenizes the encoder
1145+
text the same way).
1146+
"""
11081147
from vllm.multimodal.processing import PromptReplacement
11091148

11101149
# Get the number of text items to determine token count

0 commit comments

Comments
 (0)