Skip to content

Commit c68d685

Browse files
committed
Fix bugs and add better tests
1 parent bb29009 commit c68d685

2 files changed

Lines changed: 165 additions & 64 deletions

File tree

tests/test_florence2.py

Lines changed: 113 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
import torch
77
from transformers import Florence2Config
88

9-
MODEL_NAME = "florence-community/Florence-2-base-ft"
9+
# Allow override via env var so CI can point at a local checkpoint.
10+
MODEL_NAME = os.environ.get(
11+
"FLORENCE2_MODEL",
12+
os.path.abspath(
13+
os.path.join(os.path.dirname(__file__), "../../Florence-2-base-ft")
14+
),
15+
)
1016

1117

1218
def _small_vision_config():
@@ -124,6 +130,20 @@ def test_output_shape(self):
124130
# ---------------------------------------------------------------------------
125131

126132

133+
def _run_task(llm, processor, image, task_prompt, text_input=None, max_tokens=100):
134+
"""Helper: run one Florence-2 task and return the post-processed result."""
135+
from vllm import SamplingParams
136+
137+
prompt = task_prompt if text_input is None else task_prompt + text_input
138+
params = SamplingParams(temperature=0.0, max_tokens=max_tokens, skip_special_tokens=False)
139+
outputs = llm.generate(
140+
[{"prompt": prompt, "multi_modal_data": {"image": image}}],
141+
sampling_params=params,
142+
)
143+
raw = outputs[0].outputs[0].text
144+
return processor.post_process_generation(raw, task=task_prompt, image_size=image.size)
145+
146+
127147
@pytest.fixture(scope="module")
128148
def florence2_llm():
129149
from vllm import LLM
@@ -138,72 +158,111 @@ def florence2_llm():
138158

139159

140160
@pytest.fixture(scope="module")
141-
def stop_sign_image():
142-
from vllm.assets.image import ImageAsset
161+
def florence2_processor():
162+
from transformers import AutoProcessor
143163

144-
return ImageAsset("stop_sign").pil_image
164+
return AutoProcessor.from_pretrained(MODEL_NAME)
145165

146166

147167
@pytest.fixture(scope="module")
148-
def sampling_params():
149-
from vllm import SamplingParams
168+
def stop_sign_image():
169+
from vllm.assets.image import ImageAsset
150170

151-
return SamplingParams(
152-
temperature=0.0,
153-
max_tokens=20,
154-
repetition_penalty=1.5,
155-
skip_special_tokens=False,
156-
)
171+
return ImageAsset("stop_sign").pil_image.convert("RGB")
157172

158173

159174
@pytest.mark.slow
160175
class TestFlorenceInference:
161-
def test_caption(self, florence2_llm, stop_sign_image, sampling_params):
162-
outputs = florence2_llm.generate(
163-
[
164-
{
165-
"prompt": "<DETAILED_CAPTION>",
166-
"multi_modal_data": {"image": stop_sign_image},
167-
}
168-
],
169-
sampling_params=sampling_params,
176+
# ------------------------------------------------------------------
177+
# Caption tasks — check for semantically meaningful keywords
178+
# ------------------------------------------------------------------
179+
180+
def test_caption(self, florence2_llm, florence2_processor, stop_sign_image):
181+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<CAPTION>", max_tokens=30)
182+
text = result["<CAPTION>"].lower()
183+
assert "car" in text or "stop" in text, f"<CAPTION> output missing expected content: {text!r}"
184+
185+
def test_detailed_caption(self, florence2_llm, florence2_processor, stop_sign_image):
186+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<DETAILED_CAPTION>", max_tokens=80)
187+
text = result["<DETAILED_CAPTION>"].lower()
188+
# Must mention the car and give some background detail — guards against the
189+
# KV-cache encoder_seq_lens regression that previously produced garbled output.
190+
assert "car" in text, f"<DETAILED_CAPTION> missing 'car': {text!r}"
191+
assert len(text.split()) >= 10, f"<DETAILED_CAPTION> too short: {text!r}"
192+
193+
def test_more_detailed_caption(self, florence2_llm, florence2_processor, stop_sign_image):
194+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<MORE_DETAILED_CAPTION>", max_tokens=100)
195+
text = result["<MORE_DETAILED_CAPTION>"].lower()
196+
assert "stop sign" in text or "sign" in text, f"<MORE_DETAILED_CAPTION> missing 'stop sign': {text!r}"
197+
assert len(text.split()) >= 10, f"<MORE_DETAILED_CAPTION> too short: {text!r}"
198+
199+
# ------------------------------------------------------------------
200+
# Structured-output tasks — check schema and key labels
201+
# ------------------------------------------------------------------
202+
203+
def test_object_detection(self, florence2_llm, florence2_processor, stop_sign_image):
204+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<OD>", max_tokens=300)
205+
od = result["<OD>"]
206+
assert "bboxes" in od and "labels" in od
207+
assert len(od["bboxes"]) == len(od["labels"]) > 0
208+
# Each bbox must be a 4-element list with non-negative coords
209+
for bbox in od["bboxes"]:
210+
assert len(bbox) == 4 and all(c >= 0 for c in bbox)
211+
labels = od["labels"]
212+
assert "stop sign" in labels, f"Expected 'stop sign' in OD labels, got: {labels}"
213+
assert "car" in labels or "building" in labels, f"Expected common objects in OD labels, got: {labels}"
214+
215+
def test_dense_region_caption(self, florence2_llm, florence2_processor, stop_sign_image):
216+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<DENSE_REGION_CAPTION>", max_tokens=250)
217+
drc = result["<DENSE_REGION_CAPTION>"]
218+
assert "bboxes" in drc and "labels" in drc
219+
assert len(drc["bboxes"]) == len(drc["labels"]) > 0
220+
assert "stop sign" in drc["labels"], f"Expected 'stop sign' in dense captions, got: {drc['labels']}"
221+
222+
def test_region_proposal(self, florence2_llm, florence2_processor, stop_sign_image):
223+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<REGION_PROPOSAL>", max_tokens=100)
224+
rp = result["<REGION_PROPOSAL>"]
225+
assert "bboxes" in rp and "labels" in rp
226+
assert len(rp["bboxes"]) > 0
227+
# Region proposal labels are always empty strings
228+
assert all(label == "" for label in rp["labels"])
229+
230+
def test_ocr_with_region(self, florence2_llm, florence2_processor, stop_sign_image):
231+
result = _run_task(florence2_llm, florence2_processor, stop_sign_image, "<OCR_WITH_REGION>", max_tokens=250)
232+
ocr = result["<OCR_WITH_REGION>"]
233+
assert "quad_boxes" in ocr and "labels" in ocr
234+
assert len(ocr["quad_boxes"]) == len(ocr["labels"]) > 0
235+
# Each quad box must be 8 coords
236+
for quad in ocr["quad_boxes"]:
237+
assert len(quad) == 8
238+
# "STOP" is the most prominent text in the image
239+
joined = " ".join(ocr["labels"])
240+
assert "STOP" in joined, f"Expected 'STOP' in OCR_WITH_REGION labels, got: {joined!r}"
241+
242+
def test_caption_to_phrase_grounding(self, florence2_llm, florence2_processor, stop_sign_image):
243+
result = _run_task(
244+
florence2_llm, florence2_processor, stop_sign_image,
245+
"<CAPTION_TO_PHRASE_GROUNDING>", text_input="A stop sign on a street corner.", max_tokens=80,
170246
)
171-
assert len(outputs[0].outputs[0].text) > 0
172-
173-
def test_object_detection_has_loc_tokens(
174-
self, florence2_llm, stop_sign_image, sampling_params
175-
):
176-
outputs = florence2_llm.generate(
177-
[
178-
{
179-
"encoder_prompt": {
180-
"prompt": "<OD>",
181-
"multi_modal_data": {"image": stop_sign_image},
182-
},
183-
"decoder_prompt": "",
184-
}
185-
],
186-
sampling_params=sampling_params,
247+
cpg = result["<CAPTION_TO_PHRASE_GROUNDING>"]
248+
assert "bboxes" in cpg and "labels" in cpg
249+
assert len(cpg["bboxes"]) > 0
250+
assert any("stop sign" in lbl.lower() for lbl in cpg["labels"]), (
251+
f"Expected 'stop sign' grounded, got labels: {cpg['labels']}"
187252
)
188-
assert "<loc_" in outputs[0].outputs[0].text
189253

190-
def test_batch_inference(self, florence2_llm, stop_sign_image, sampling_params):
254+
# ------------------------------------------------------------------
255+
# Batch tests
256+
# ------------------------------------------------------------------
257+
258+
def test_batch_inference(self, florence2_llm, florence2_processor, stop_sign_image):
259+
"""Multiple prompts in one batch must all produce non-empty output."""
260+
from vllm import SamplingParams
261+
262+
params = SamplingParams(temperature=0.0, max_tokens=30, skip_special_tokens=False)
191263
prompts = [
192264
{"prompt": "<CAPTION>", "multi_modal_data": {"image": stop_sign_image}},
193-
{
194-
"prompt": "<DETAILED_CAPTION>",
195-
"multi_modal_data": {"image": stop_sign_image},
196-
},
265+
{"prompt": "<DETAILED_CAPTION>", "multi_modal_data": {"image": stop_sign_image}},
197266
]
198-
outputs = florence2_llm.generate(prompts, sampling_params=sampling_params)
267+
outputs = florence2_llm.generate(prompts, sampling_params=params)
199268
assert all(len(o.outputs[0].text) > 0 for o in outputs)
200-
201-
def test_encoder_length_within_limit(self, stop_sign_image):
202-
"""Processor output must not exceed BART max_position_embeddings."""
203-
from transformers import AutoProcessor
204-
205-
processor = AutoProcessor.from_pretrained(MODEL_NAME)
206-
out = processor(
207-
text="<DETAILED_CAPTION>", images=stop_sign_image, return_tensors="pt"
208-
)
209-
assert out["input_ids"].shape[1] <= 1024

vllm_bart_plugin/florence2.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,19 @@ def get_dummy_mm_data(
718718

719719
class Florence2MultiModalProcessor(EncDecMultiModalProcessor[Florence2ProcessingInfo]):
720720

721+
def __init__(self, info, dummy_inputs, *, cache=None) -> None:
722+
super().__init__(info, dummy_inputs, cache=cache)
723+
# Florence2Config does not expose decoder_start_token_id at the
724+
# top level (it lives in text_config), so vLLM falls back to BOS
725+
# (token 0) and incorrectly prepends it to the decoder prompt.
726+
# Patch the top-level hf_config so vLLM's _prepare_decoder_input_ids
727+
# sees the real value (EOS / token 2) and leaves our prompt intact.
728+
hf_config = info.get_hf_config()
729+
if getattr(hf_config, "decoder_start_token_id", None) is None:
730+
hf_config.decoder_start_token_id = (
731+
hf_config.text_config.decoder_start_token_id
732+
)
733+
721734
def _hf_processor_applies_updates(
722735
self,
723736
prompt_text: str,
@@ -742,7 +755,16 @@ def create_decoder_prompt(
742755
prompt: str | list[int],
743756
mm_data: MultiModalDataDict,
744757
) -> str | list[int]:
745-
return [self.info.get_hf_config().text_config.eos_token_id]
758+
text_config = self.info.get_hf_config().text_config
759+
# Decoder prompt mirrors what transformers does before open-ended
760+
# generation: start with decoder_start_token_id (</s>, token 2),
761+
# then include forced_bos_token_id (<s>, token 0) so that vLLM
762+
# generates from the same position as transformers step 2.
763+
decoder_prompt = [text_config.decoder_start_token_id]
764+
forced_bos = getattr(text_config, "forced_bos_token_id", None)
765+
if forced_bos is not None:
766+
decoder_prompt.append(forced_bos)
767+
return decoder_prompt
746768

747769
def _apply_hf_processor_tokens_only(
748770
self,
@@ -793,20 +815,40 @@ def _get_prompt_updates(
793815
hf_processor_mm_kwargs: Mapping[str, object],
794816
out_mm_kwargs: MultiModalKwargsItems,
795817
) -> Sequence[PromptUpdate]:
796-
hf_config = self.info.get_hf_config()
797-
# Use image_token_id (51289) — this is what the Florence2Processor
798-
# inserts into input_ids. With _hf_processor_applies_updates=True,
799-
# vllm will FIND these tokens in the existing prompt rather than
800-
# inserting new ones (so no token doubling / length overflow).
801-
image_token_id = hf_config.image_token_id
802-
num_image_tokens = self.info.get_num_image_tokens()
803-
image_tokens = [image_token_id] * num_image_tokens
818+
# The placeholder must cover the FULL encoder input sequence (image
819+
# tokens + text/task tokens) so that vLLM's _get_encoder_seq_lens
820+
# computes the correct value for cross-attention KV cache allocation.
821+
# Using only the image token count (577) would cause cross-attention
822+
# to read only 577/590 K/V pairs, skipping the task-prompt tokens.
823+
#
824+
# With _hf_processor_applies_updates=True, vLLM detects the existing
825+
# token sequence rather than inserting new tokens. By setting the
826+
# insertion to the full encoder_input_ids sequence, the detected
827+
# placeholder range covers all 590 encoder tokens.
828+
insertion: list[int]
829+
image_items = out_mm_kwargs.get("image", [])
830+
if image_items:
831+
item_data = image_items[0].get_data()
832+
enc_ids = item_data.get("encoder_input_ids")
833+
if enc_ids is not None:
834+
insertion = enc_ids.tolist()
835+
else:
836+
# Cache hit: encoder_input_ids not available; fall back.
837+
hf_config = self.info.get_hf_config()
838+
insertion = (
839+
[hf_config.image_token_id] * self.info.get_num_image_tokens()
840+
)
841+
else:
842+
hf_config = self.info.get_hf_config()
843+
insertion = (
844+
[hf_config.image_token_id] * self.info.get_num_image_tokens()
845+
)
804846

805847
return [
806848
PromptInsertion(
807849
modality="image",
808850
target=PromptIndexTargets.start(),
809-
insertion=image_tokens,
851+
insertion=insertion,
810852
)
811853
]
812854

0 commit comments

Comments
 (0)