66import torch
77from transformers import Florence2Config
88
9- MODEL_NAME = "florence-community/Florence-2-base-ft"
9+ # Allow override via env var so CI can point at a local checkpoint.
10+ MODEL_NAME = os .environ .get (
11+ "FLORENCE2_MODEL" ,
12+ os .path .abspath (
13+ os .path .join (os .path .dirname (__file__ ), "../../Florence-2-base-ft" )
14+ ),
15+ )
1016
1117
1218def _small_vision_config ():
@@ -124,6 +130,20 @@ def test_output_shape(self):
124130# ---------------------------------------------------------------------------
125131
126132
133+ def _run_task (llm , processor , image , task_prompt , text_input = None , max_tokens = 100 ):
134+ """Helper: run one Florence-2 task and return the post-processed result."""
135+ from vllm import SamplingParams
136+
137+ prompt = task_prompt if text_input is None else task_prompt + text_input
138+ params = SamplingParams (temperature = 0.0 , max_tokens = max_tokens , skip_special_tokens = False )
139+ outputs = llm .generate (
140+ [{"prompt" : prompt , "multi_modal_data" : {"image" : image }}],
141+ sampling_params = params ,
142+ )
143+ raw = outputs [0 ].outputs [0 ].text
144+ return processor .post_process_generation (raw , task = task_prompt , image_size = image .size )
145+
146+
127147@pytest .fixture (scope = "module" )
128148def florence2_llm ():
129149 from vllm import LLM
@@ -138,72 +158,111 @@ def florence2_llm():
138158
139159
140160@pytest .fixture (scope = "module" )
141- def stop_sign_image ():
142- from vllm . assets . image import ImageAsset
161+ def florence2_processor ():
162+ from transformers import AutoProcessor
143163
144- return ImageAsset ( "stop_sign" ). pil_image
164+ return AutoProcessor . from_pretrained ( MODEL_NAME )
145165
146166
147167@pytest .fixture (scope = "module" )
148- def sampling_params ():
149- from vllm import SamplingParams
168+ def stop_sign_image ():
169+ from vllm . assets . image import ImageAsset
150170
151- return SamplingParams (
152- temperature = 0.0 ,
153- max_tokens = 20 ,
154- repetition_penalty = 1.5 ,
155- skip_special_tokens = False ,
156- )
171+ return ImageAsset ("stop_sign" ).pil_image .convert ("RGB" )
157172
158173
159174@pytest .mark .slow
160175class TestFlorenceInference :
161- def test_caption (self , florence2_llm , stop_sign_image , sampling_params ):
162- outputs = florence2_llm .generate (
163- [
164- {
165- "prompt" : "<DETAILED_CAPTION>" ,
166- "multi_modal_data" : {"image" : stop_sign_image },
167- }
168- ],
169- sampling_params = sampling_params ,
176+ # ------------------------------------------------------------------
177+ # Caption tasks — check for semantically meaningful keywords
178+ # ------------------------------------------------------------------
179+
180+ def test_caption (self , florence2_llm , florence2_processor , stop_sign_image ):
181+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<CAPTION>" , max_tokens = 30 )
182+ text = result ["<CAPTION>" ].lower ()
183+ assert "car" in text or "stop" in text , f"<CAPTION> output missing expected content: { text !r} "
184+
185+ def test_detailed_caption (self , florence2_llm , florence2_processor , stop_sign_image ):
186+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<DETAILED_CAPTION>" , max_tokens = 80 )
187+ text = result ["<DETAILED_CAPTION>" ].lower ()
188+ # Must mention the car and give some background detail — guards against the
189+ # KV-cache encoder_seq_lens regression that previously produced garbled output.
190+ assert "car" in text , f"<DETAILED_CAPTION> missing 'car': { text !r} "
191+ assert len (text .split ()) >= 10 , f"<DETAILED_CAPTION> too short: { text !r} "
192+
193+ def test_more_detailed_caption (self , florence2_llm , florence2_processor , stop_sign_image ):
194+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<MORE_DETAILED_CAPTION>" , max_tokens = 100 )
195+ text = result ["<MORE_DETAILED_CAPTION>" ].lower ()
196+ assert "stop sign" in text or "sign" in text , f"<MORE_DETAILED_CAPTION> missing 'stop sign': { text !r} "
197+ assert len (text .split ()) >= 10 , f"<MORE_DETAILED_CAPTION> too short: { text !r} "
198+
199+ # ------------------------------------------------------------------
200+ # Structured-output tasks — check schema and key labels
201+ # ------------------------------------------------------------------
202+
203+ def test_object_detection (self , florence2_llm , florence2_processor , stop_sign_image ):
204+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<OD>" , max_tokens = 300 )
205+ od = result ["<OD>" ]
206+ assert "bboxes" in od and "labels" in od
207+ assert len (od ["bboxes" ]) == len (od ["labels" ]) > 0
208+ # Each bbox must be a 4-element list with non-negative coords
209+ for bbox in od ["bboxes" ]:
210+ assert len (bbox ) == 4 and all (c >= 0 for c in bbox )
211+ labels = od ["labels" ]
212+ assert "stop sign" in labels , f"Expected 'stop sign' in OD labels, got: { labels } "
213+ assert "car" in labels or "building" in labels , f"Expected common objects in OD labels, got: { labels } "
214+
215+ def test_dense_region_caption (self , florence2_llm , florence2_processor , stop_sign_image ):
216+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<DENSE_REGION_CAPTION>" , max_tokens = 250 )
217+ drc = result ["<DENSE_REGION_CAPTION>" ]
218+ assert "bboxes" in drc and "labels" in drc
219+ assert len (drc ["bboxes" ]) == len (drc ["labels" ]) > 0
220+ assert "stop sign" in drc ["labels" ], f"Expected 'stop sign' in dense captions, got: { drc ['labels' ]} "
221+
222+ def test_region_proposal (self , florence2_llm , florence2_processor , stop_sign_image ):
223+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<REGION_PROPOSAL>" , max_tokens = 100 )
224+ rp = result ["<REGION_PROPOSAL>" ]
225+ assert "bboxes" in rp and "labels" in rp
226+ assert len (rp ["bboxes" ]) > 0
227+ # Region proposal labels are always empty strings
228+ assert all (label == "" for label in rp ["labels" ])
229+
230+ def test_ocr_with_region (self , florence2_llm , florence2_processor , stop_sign_image ):
231+ result = _run_task (florence2_llm , florence2_processor , stop_sign_image , "<OCR_WITH_REGION>" , max_tokens = 250 )
232+ ocr = result ["<OCR_WITH_REGION>" ]
233+ assert "quad_boxes" in ocr and "labels" in ocr
234+ assert len (ocr ["quad_boxes" ]) == len (ocr ["labels" ]) > 0
235+ # Each quad box must be 8 coords
236+ for quad in ocr ["quad_boxes" ]:
237+ assert len (quad ) == 8
238+ # "STOP" is the most prominent text in the image
239+ joined = " " .join (ocr ["labels" ])
240+ assert "STOP" in joined , f"Expected 'STOP' in OCR_WITH_REGION labels, got: { joined !r} "
241+
242+ def test_caption_to_phrase_grounding (self , florence2_llm , florence2_processor , stop_sign_image ):
243+ result = _run_task (
244+ florence2_llm , florence2_processor , stop_sign_image ,
245+ "<CAPTION_TO_PHRASE_GROUNDING>" , text_input = "A stop sign on a street corner." , max_tokens = 80 ,
170246 )
171- assert len (outputs [0 ].outputs [0 ].text ) > 0
172-
173- def test_object_detection_has_loc_tokens (
174- self , florence2_llm , stop_sign_image , sampling_params
175- ):
176- outputs = florence2_llm .generate (
177- [
178- {
179- "encoder_prompt" : {
180- "prompt" : "<OD>" ,
181- "multi_modal_data" : {"image" : stop_sign_image },
182- },
183- "decoder_prompt" : "" ,
184- }
185- ],
186- sampling_params = sampling_params ,
247+ cpg = result ["<CAPTION_TO_PHRASE_GROUNDING>" ]
248+ assert "bboxes" in cpg and "labels" in cpg
249+ assert len (cpg ["bboxes" ]) > 0
250+ assert any ("stop sign" in lbl .lower () for lbl in cpg ["labels" ]), (
251+ f"Expected 'stop sign' grounded, got labels: { cpg ['labels' ]} "
187252 )
188- assert "<loc_" in outputs [0 ].outputs [0 ].text
189253
190- def test_batch_inference (self , florence2_llm , stop_sign_image , sampling_params ):
254+ # ------------------------------------------------------------------
255+ # Batch tests
256+ # ------------------------------------------------------------------
257+
258+ def test_batch_inference (self , florence2_llm , florence2_processor , stop_sign_image ):
259+ """Multiple prompts in one batch must all produce non-empty output."""
260+ from vllm import SamplingParams
261+
262+ params = SamplingParams (temperature = 0.0 , max_tokens = 30 , skip_special_tokens = False )
191263 prompts = [
192264 {"prompt" : "<CAPTION>" , "multi_modal_data" : {"image" : stop_sign_image }},
193- {
194- "prompt" : "<DETAILED_CAPTION>" ,
195- "multi_modal_data" : {"image" : stop_sign_image },
196- },
265+ {"prompt" : "<DETAILED_CAPTION>" , "multi_modal_data" : {"image" : stop_sign_image }},
197266 ]
198- outputs = florence2_llm .generate (prompts , sampling_params = sampling_params )
267+ outputs = florence2_llm .generate (prompts , sampling_params = params )
199268 assert all (len (o .outputs [0 ].text ) > 0 for o in outputs )
200-
201- def test_encoder_length_within_limit (self , stop_sign_image ):
202- """Processor output must not exceed BART max_position_embeddings."""
203- from transformers import AutoProcessor
204-
205- processor = AutoProcessor .from_pretrained (MODEL_NAME )
206- out = processor (
207- text = "<DETAILED_CAPTION>" , images = stop_sign_image , return_tensors = "pt"
208- )
209- assert out ["input_ids" ].shape [1 ] <= 1024
0 commit comments