diff --git a/tests/test_vllm_018_compat.py b/tests/test_vllm_018_compat.py new file mode 100644 index 0000000..921ffba --- /dev/null +++ b/tests/test_vllm_018_compat.py @@ -0,0 +1,47 @@ +"""Regression tests for vLLM 0.18 compatibility in the BART processor.""" + +import torch + + +def test_text_data_parser_handles_v018_empty_inputs(): + from vllm_bart_plugin.bart import TextDataParser + + parser = TextDataParser() + + assert parser._parse_text_data("") is None + assert parser._parse_text_data([]) is None + + +def test_create_encoder_prompt_uses_placeholder_token(): + from vllm_bart_plugin.bart import BartMultiModalProcessor + + processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor) + + assert processor.create_encoder_prompt("decoder text", {"texts": ["encoder text"]}) == [0] + + +def test_call_hf_processor_accepts_pretokenized_decoder_prompt(): + from vllm_bart_plugin.bart import BartMultiModalProcessor + + class FakeTokenizer: + def __call__(self, text, return_tensors="pt", **kwargs): + if text == "encoder text": + return {"input_ids": torch.tensor([[11, 12, 13]])} + return {"input_ids": torch.tensor([[21, 22]])} + + class FakeInfo: + def get_tokenizer(self): + return FakeTokenizer() + + processor = BartMultiModalProcessor.__new__(BartMultiModalProcessor) + processor.info = FakeInfo() + + out = processor._call_hf_processor( + [7, 8, 9], + {"texts": ["encoder text"]}, + {}, + {}, + ) + + assert torch.equal(out["encoder_input_ids"], torch.tensor([[11, 12, 13]])) + assert torch.equal(out["input_ids"], torch.tensor([[7, 8, 9]])) diff --git a/vllm_bart_plugin/bart.py b/vllm_bart_plugin/bart.py index fcced7d..62cca02 100644 --- a/vllm_bart_plugin/bart.py +++ b/vllm_bart_plugin/bart.py @@ -996,7 +996,10 @@ def _parse_text_data( if data is None: return TextProcessorItems(None) - if self._is_empty(data): + # _is_empty was removed in vLLM >=0.18; handle emptiness inline + if isinstance(data, str) and not data: + return None + if isinstance(data, list) and len(data) == 0: return None # Text data should be a string or list of strings @@ -1030,15 +1033,11 @@ def create_encoder_prompt( prompt: str | list[int], mm_data: MultiModalDataDict, ) -> str | list[int]: - if not prompt: - return [0] - tokenizer = self.info.get_tokenizer() - tokens = tokenizer( - prompt, - add_special_tokens=False, - return_tensors="pt", - )["input_ids"].flatten() - return tokens.tolist() + # In vLLM >=0.18, `prompt` here is the DECODER prompt text, not the + # encoder text. The encoder content lives in mm_data ("text" key). + # Always return [0] as a single placeholder token; _get_prompt_updates + # will replace it with the correct number of encoder token slots. + return [0] def create_decoder_prompt( self, @@ -1079,14 +1078,19 @@ def _call_hf_processor( ) result["encoder_input_ids"] = encoder_tokenized["input_ids"] - # Always tokenize the prompt (for decoder or as dummy) - # This will be popped by the base class - prompt_tokenized = tokenizer( - prompt if prompt else "", - return_tensors="pt", - **tok_kwargs, - ) - result["input_ids"] = prompt_tokenized["input_ids"] + # Always produce input_ids for the decoder prompt. + # In vLLM >=0.18 the rendering pipeline may call _call_hf_processor + # with an already-tokenized prompt (a list of ints) instead of a str. + # Handle both cases. + if isinstance(prompt, (list, tuple)) and len(prompt) > 0 and isinstance(prompt[0], int): + result["input_ids"] = torch.tensor([prompt]) + else: + prompt_tokenized = tokenizer( + prompt if prompt else "", + return_tensors="pt", + **tok_kwargs, + ) + result["input_ids"] = prompt_tokenized["input_ids"] return BatchFeature(result)