Skip to content

Commit 3ffb803

Browse files
committed
Fix BART processor compatibility with vLLM 0.18
1 parent 331e24a commit 3ffb803

2 files changed

Lines changed: 70 additions & 18 deletions

File tree

tests/test_vllm_018_compat.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Regression tests for vLLM 0.18 compatibility in the BART processor."""
2+
3+
import torch
4+
5+
6+
def test_text_data_parser_handles_v018_empty_inputs():
7+
from vllm_bart_plugin.bart import TextDataParser
8+
9+
parser = TextDataParser()
10+
11+
assert parser._parse_text_data("") is None
12+
assert parser._parse_text_data([]) is None
13+
14+
15+
def test_create_encoder_prompt_uses_placeholder_token():
16+
from vllm_bart_plugin.bart import BartMultiModalProcessor
17+
18+
processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
19+
20+
assert processor.create_encoder_prompt("<s>decoder text", {"texts": ["encoder text"]}) == [0]
21+
22+
23+
def test_call_hf_processor_accepts_pretokenized_decoder_prompt():
24+
from vllm_bart_plugin.bart import BartMultiModalProcessor
25+
26+
class FakeTokenizer:
27+
def __call__(self, text, return_tensors="pt", **kwargs):
28+
if text == "encoder text":
29+
return {"input_ids": torch.tensor([[11, 12, 13]])}
30+
return {"input_ids": torch.tensor([[21, 22]])}
31+
32+
class FakeInfo:
33+
def get_tokenizer(self):
34+
return FakeTokenizer()
35+
36+
processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
37+
processor.info = FakeInfo()
38+
39+
out = processor._call_hf_processor(
40+
[7, 8, 9],
41+
{"texts": ["encoder text"]},
42+
{},
43+
{},
44+
)
45+
46+
assert torch.equal(out["encoder_input_ids"], torch.tensor([[11, 12, 13]]))
47+
assert torch.equal(out["input_ids"], torch.tensor([[7, 8, 9]]))

vllm_bart_plugin/bart.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,10 @@ def _parse_text_data(
996996
if data is None:
997997
return TextProcessorItems(None)
998998

999-
if self._is_empty(data):
999+
# _is_empty was removed in vLLM >=0.18; handle emptiness inline
1000+
if isinstance(data, str) and not data:
1001+
return None
1002+
if isinstance(data, list) and len(data) == 0:
10001003
return None
10011004

10021005
# Text data should be a string or list of strings
@@ -1030,15 +1033,11 @@ def create_encoder_prompt(
10301033
prompt: str | list[int],
10311034
mm_data: MultiModalDataDict,
10321035
) -> str | list[int]:
1033-
if not prompt:
1034-
return [0]
1035-
tokenizer = self.info.get_tokenizer()
1036-
tokens = tokenizer(
1037-
prompt,
1038-
add_special_tokens=False,
1039-
return_tensors="pt",
1040-
)["input_ids"].flatten()
1041-
return tokens.tolist()
1036+
# In vLLM >=0.18, `prompt` here is the DECODER prompt text, not the
1037+
# encoder text. The encoder content lives in mm_data ("text" key).
1038+
# Always return [0] as a single placeholder token; _get_prompt_updates
1039+
# will replace it with the correct number of encoder token slots.
1040+
return [0]
10421041

10431042
def create_decoder_prompt(
10441043
self,
@@ -1079,14 +1078,20 @@ def _call_hf_processor(
10791078
)
10801079
result["encoder_input_ids"] = encoder_tokenized["input_ids"]
10811080

1082-
# Always tokenize the prompt (for decoder or as dummy)
1083-
# This will be popped by the base class
1084-
prompt_tokenized = tokenizer(
1085-
prompt if prompt else "",
1086-
return_tensors="pt",
1087-
**tok_kwargs,
1088-
)
1089-
result["input_ids"] = prompt_tokenized["input_ids"]
1081+
# Always produce input_ids for the decoder prompt.
1082+
# In vLLM >=0.18 the rendering pipeline may call _call_hf_processor
1083+
# with an already-tokenized prompt (a list of ints) instead of a str.
1084+
# Handle both cases.
1085+
import torch as _torch
1086+
if isinstance(prompt, (list, tuple)) and len(prompt) > 0 and isinstance(prompt[0], int):
1087+
result["input_ids"] = _torch.tensor([prompt])
1088+
else:
1089+
prompt_tokenized = tokenizer(
1090+
prompt if prompt else "",
1091+
return_tensors="pt",
1092+
**tok_kwargs,
1093+
)
1094+
result["input_ids"] = prompt_tokenized["input_ids"]
10901095

10911096
return BatchFeature(result)
10921097

0 commit comments

Comments
 (0)