Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/test_vllm_018_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Regression tests for vLLM 0.18 compatibility in the BART processor."""

import torch


def test_text_data_parser_handles_v018_empty_inputs():
    """Check that empty text payloads (``""`` and ``[]``) parse to ``None``."""
    from vllm_bart_plugin.bart import TextDataParser

    text_parser = TextDataParser()

    # Each empty shape gets its own assert so a failure names the culprit.
    empty_string_result = text_parser._parse_text_data("")
    assert empty_string_result is None

    empty_list_result = text_parser._parse_text_data([])
    assert empty_list_result is None


def test_create_encoder_prompt_uses_placeholder_token():
    """Check that ``create_encoder_prompt`` returns the placeholder ``[0]``."""
    from vllm_bart_plugin.bart import BartMultiModalProcessor

    # Allocate without running __init__; this test sets no instance attributes.
    bare_processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)

    decoder_prompt = "<s>decoder text"
    mm_data = {"texts": ["encoder text"]}
    encoder_prompt = bare_processor.create_encoder_prompt(decoder_prompt, mm_data)

    assert encoder_prompt == [0]


def test_call_hf_processor_accepts_pretokenized_decoder_prompt():
    """Check that a list-of-ints decoder prompt is passed through as ``input_ids``.

    The encoder text is expected to be tokenized by the (fake) tokenizer,
    while the already-tokenized decoder prompt should come back unchanged,
    wrapped in a batch dimension.
    """
    from vllm_bart_plugin.bart import BartMultiModalProcessor

    class FakeTokenizer:
        """Tokenizer stub returning fixed ids keyed on the input text."""

        def __call__(self, text, return_tensors="pt", **kwargs):
            # Only the exact encoder string maps to the encoder ids; any
            # other input falls back to the two-token sequence.
            ids = [[11, 12, 13]] if text == "encoder text" else [[21, 22]]
            return {"input_ids": torch.tensor(ids)}

    class FakeInfo:
        """Processing-info stub exposing only ``get_tokenizer``."""

        def get_tokenizer(self):
            return FakeTokenizer()

    # Allocate without running __init__, then attach the one attribute
    # this test provides.
    bare_processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
    bare_processor.info = FakeInfo()

    pretokenized_prompt = [7, 8, 9]
    mm_data = {"texts": ["encoder text"]}
    features = bare_processor._call_hf_processor(pretokenized_prompt, mm_data, {}, {})

    expected_encoder_ids = torch.tensor([[11, 12, 13]])
    expected_decoder_ids = torch.tensor([[7, 8, 9]])
    assert torch.equal(features["encoder_input_ids"], expected_encoder_ids)
    assert torch.equal(features["input_ids"], expected_decoder_ids)
40 changes: 22 additions & 18 deletions vllm_bart_plugin/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,10 @@ def _parse_text_data(
if data is None:
return TextProcessorItems(None)

if self._is_empty(data):
# _is_empty was removed in vLLM >=0.18; handle emptiness inline
if isinstance(data, str) and not data:
return None
if isinstance(data, list) and len(data) == 0:
return None

# Text data should be a string or list of strings
Expand Down Expand Up @@ -1030,15 +1033,11 @@ def create_encoder_prompt(
prompt: str | list[int],
mm_data: MultiModalDataDict,
) -> str | list[int]:
if not prompt:
return [0]
tokenizer = self.info.get_tokenizer()
tokens = tokenizer(
prompt,
add_special_tokens=False,
return_tensors="pt",
)["input_ids"].flatten()
return tokens.tolist()
# In vLLM >=0.18, `prompt` here is the DECODER prompt text, not the
# encoder text. The encoder content lives in mm_data ("text" key).
# Always return [0] as a single placeholder token; _get_prompt_updates
# will replace it with the correct number of encoder token slots.
return [0]

def create_decoder_prompt(
self,
Expand Down Expand Up @@ -1079,14 +1078,19 @@ def _call_hf_processor(
)
result["encoder_input_ids"] = encoder_tokenized["input_ids"]

# Always tokenize the prompt (for decoder or as dummy)
# This will be popped by the base class
prompt_tokenized = tokenizer(
prompt if prompt else "",
return_tensors="pt",
**tok_kwargs,
)
result["input_ids"] = prompt_tokenized["input_ids"]
# Always produce input_ids for the decoder prompt.
# In vLLM >=0.18 the rendering pipeline may call _call_hf_processor
# with an already-tokenized prompt (a list of ints) instead of a str.
# Handle both cases.
if isinstance(prompt, (list, tuple)) and len(prompt) > 0 and isinstance(prompt[0], int):
result["input_ids"] = torch.tensor([prompt])
else:
prompt_tokenized = tokenizer(
prompt if prompt else "",
return_tensors="pt",
**tok_kwargs,
)
result["input_ids"] = prompt_tokenized["input_ids"]

return BatchFeature(result)

Expand Down