Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/test_vllm_018_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Regression tests for vLLM 0.18 compatibility in the BART processor."""

import torch


def test_text_data_parser_handles_v018_empty_inputs():
    """Empty text payloads must make ``_parse_text_data`` return ``None``.

    Both the empty string and the empty list are checked.
    """
    # NOTE(review): ``_parse_text_data`` in bart.py appears to return
    # ``TextProcessorItems(None)`` when ``not len(data)`` -- confirm that the
    # ``is None`` expectation below matches the implementation.
    from vllm_bart_plugin.bart import TextDataParser

    text_parser = TextDataParser()

    for empty_payload in ("", []):
        assert text_parser._parse_text_data(empty_payload) is None


def test_create_encoder_prompt_uses_placeholder_token():
    """Non-legacy prompts must collapse to the single ``[0]`` placeholder.

    ``create_encoder_prompt`` tokenizes a non-empty ``str`` prompt through
    ``self.info.get_tokenizer()``; every other prompt (token IDs or an empty
    string) yields ``[0]``. The processor here is built with ``__new__`` and
    therefore has no ``info`` attribute, so only the placeholder inputs can
    be exercised -- a non-empty string would hit the tokenizer branch and
    fail on the missing attribute instead of returning ``[0]``.
    """
    from vllm_bart_plugin.bart import BartMultiModalProcessor

    # Bypass __init__: the placeholder path must not need a tokenizer.
    processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
    mm_data = {"texts": ["encoder text"]}

    # Pre-tokenized decoder prompt (list of token IDs).
    assert processor.create_encoder_prompt([7, 8, 9], mm_data) == [0]
    # Empty string prompt.
    assert processor.create_encoder_prompt("", mm_data) == [0]


def test_call_hf_processor_accepts_pretokenized_decoder_prompt():
    """A list-of-ints decoder prompt must pass through ``_call_hf_processor``.

    The encoder text is expected to be tokenized by the (fake) tokenizer,
    while the already-tokenized decoder prompt is expected to come back
    unchanged as a single-row ``input_ids`` tensor.
    """
    from vllm_bart_plugin.bart import BartMultiModalProcessor

    class FakeTokenizer:
        """Returns fixed IDs: one row for the encoder text, another otherwise."""

        def __call__(self, text, return_tensors="pt", **kwargs):
            token_rows = [[11, 12, 13]] if text == "encoder text" else [[21, 22]]
            return {"input_ids": torch.tensor(token_rows)}

    class FakeInfo:
        """Minimal stand-in exposing only ``get_tokenizer``."""

        def get_tokenizer(self):
            return FakeTokenizer()

    # Skip __init__ and attach only the attribute the test supplies.
    processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor)
    processor.info = FakeInfo()

    decoder_token_ids = [7, 8, 9]
    mm_data = {"texts": ["encoder text"]}
    out = processor._call_hf_processor(decoder_token_ids, mm_data, {}, {})

    expected_encoder_ids = torch.tensor([[11, 12, 13]])
    expected_decoder_ids = torch.tensor([[7, 8, 9]])
    assert torch.equal(out["encoder_input_ids"], expected_encoder_ids)
    assert torch.equal(out["input_ids"], expected_decoder_ids)
91 changes: 65 additions & 26 deletions vllm_bart_plugin/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,21 @@ def get_data_parser(self) -> MultiModalDataParser:
return TextDataParser()


# vLLM >=0.18 moved tokenization defaults from a global enc-dec override
# (InputPreprocessor._get_tokenization_kw) into per-model ProcessingInfo.
# The old code forced add_special_tokens=False for every is_encoder_decoder
# model; replicate that here so the renderer does not inject extra BOS/EOS
# into the decoder prompt. On vLLM <0.18 the method does not exist on the
# base class and is not needed (the global override handles it).
if hasattr(BaseProcessingInfo, "get_default_tok_params"):

    def _bart_get_default_tok_params(self):
        """Return the parent's default tokenization params with
        ``add_special_tokens`` forced to ``False``."""
        # Two-argument super(): this function is defined outside the class
        # body, so the zero-argument form is not available here.
        parent_params = super(BartProcessingInfo, self).get_default_tok_params()
        return parent_params.with_kwargs(add_special_tokens=False)

    # Attached only when the base class exposes the hook.
    BartProcessingInfo.get_default_tok_params = _bart_get_default_tok_params  # type: ignore[attr-defined]


class BartDummyInputsBuilder(BaseDummyInputsBuilder[BartProcessingInfo]):
"""Builds dummy inputs for profiling BART models."""

Expand Down Expand Up @@ -993,12 +1008,10 @@ def _parse_text_data(
data: ModalityData[str],
) -> ModalityDataItems[Any, Any] | None:
"""Parse text data for BART."""
if data is None:
# _is_empty was removed in vLLM >=0.18; handle emptiness inline
if data is None or not len(data):
return TextProcessorItems(None)

if self._is_empty(data):
return None

# Text data should be a string or list of strings
if isinstance(data, str) or is_list_of(data, str):
return TextProcessorItems(data)
Expand Down Expand Up @@ -1030,15 +1043,23 @@ def create_encoder_prompt(
prompt: str | list[int],
mm_data: MultiModalDataDict,
) -> str | list[int]:
if not prompt:
return [0]
tokenizer = self.info.get_tokenizer()
tokens = tokenizer(
prompt,
add_special_tokens=False,
return_tensors="pt",
)["input_ids"].flatten()
return tokens.tolist()
# vLLM compatibility:
# - Legacy (<0.18): prompt is encoder text (str) — tokenize directly.
# - Modern (>=0.18): prompt is decoder token IDs or empty str from
# profiling — return a single [0] placeholder that _get_prompt_updates
# will expand to the real encoder token count. The placeholder IDs
# are structural (KV-cache sizing); the actual encoder computation
# uses encoder_input_ids from mm_kwargs.
if isinstance(prompt, str) and prompt:
tokenizer = self.info.get_tokenizer()
tokens = tokenizer(
prompt,
add_special_tokens=False,
return_tensors="pt",
)["input_ids"].flatten()
return tokens.tolist()

return [0]

def create_decoder_prompt(
self,
Expand All @@ -1056,10 +1077,16 @@ def _call_hf_processor(
tok_kwargs: Mapping[str, object],
):
"""
BART doesn't have a HuggingFace Processor - it only has a tokenizer.
We tokenize both the prompt (decoder) and encoder text from mm_data.
BART doesn't have a HuggingFace Processor — it only has a tokenizer.

Produces two sets of token IDs:
- ``encoder_input_ids``: tokenized encoder text from ``mm_data["texts"]``
- ``input_ids``: tokenized decoder prompt (used by the base class to
build ``prompt_token_ids``)

Encoder text is always tokenized with ``add_special_tokens=False`` to
match v0.16 behaviour and stay consistent with ``_get_prompt_updates``.
"""
# tok_kwargs["add_special_tokens"] = False
from transformers.feature_extraction_utils import BatchFeature

tokenizer = self.info.get_tokenizer()
Expand All @@ -1069,24 +1096,29 @@ def _call_hf_processor(
result = {}

if has_encoder_data:
# Tokenize the encoder text from mm_data
encoder_texts = mm_data["texts"]
encoder_text = encoder_texts[0] if encoder_texts else ""
# Tokenize the encoder text from mm_data
encoder_tokenized = tokenizer(
encoder_text,
return_tensors="pt",
**tok_kwargs,
add_special_tokens=False,
)
result["encoder_input_ids"] = encoder_tokenized["input_ids"]

# Always tokenize the prompt (for decoder or as dummy)
# This will be popped by the base class
prompt_tokenized = tokenizer(
prompt if prompt else "",
return_tensors="pt",
**tok_kwargs,
)
result["input_ids"] = prompt_tokenized["input_ids"]
# Always produce input_ids for the decoder prompt.
# In vLLM >=0.18 the rendering pipeline may call _call_hf_processor
# with an already-tokenized prompt (a list of ints) instead of a str.
# Handle both cases.
if isinstance(prompt, (list, tuple)) and len(prompt) > 0 and isinstance(prompt[0], int):
result["input_ids"] = torch.tensor([prompt])
else:
prompt_tokenized = tokenizer(
prompt if prompt else "",
return_tensors="pt",
**tok_kwargs,
)
result["input_ids"] = prompt_tokenized["input_ids"]

return BatchFeature(result)

Expand All @@ -1105,6 +1137,13 @@ def _get_prompt_updates(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
"""Replace the single [0] encoder placeholder with N placeholder
tokens, where N equals the tokenized length of the encoder text.

The token count must use ``add_special_tokens=False`` to stay
consistent with ``_call_hf_processor`` (which tokenizes the encoder
text the same way).
"""
from vllm.multimodal.processing import PromptReplacement

# Get the number of text items to determine token count
Expand Down