Skip to content

Commit b5334df

Browse files
committed
simplify by addressing comments
1 parent 19d1c6e commit b5334df

2 files changed

Lines changed: 40 additions & 449 deletions

File tree

tests/test_florence2.py

Lines changed: 21 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -30,55 +30,6 @@ def _small_vision_config():
3030
# ---------------------------------------------------------------------------
3131

3232

33-
class TestFlorenceVisionDropPath:
34-
def test_eval_is_identity(self):
35-
from vllm_bart_plugin.florence2 import Florence2VisionDropPath
36-
37-
m = Florence2VisionDropPath(drop_prob=0.9).eval()
38-
x = torch.randn(2, 16)
39-
assert torch.equal(m(x), x)
40-
41-
def test_training_drops_samples(self):
42-
from vllm_bart_plugin.florence2 import Florence2VisionDropPath
43-
44-
torch.manual_seed(0)
45-
m = Florence2VisionDropPath(drop_prob=0.5).train()
46-
out = m(torch.ones(64, 16))
47-
assert not torch.all(out == 1)
48-
49-
50-
class TestFlorenceVisionConvEmbed:
51-
@pytest.mark.parametrize("pre_norm", [True, False])
52-
def test_output_channels(self, pre_norm):
53-
from vllm_bart_plugin.florence2 import Florence2VisionConvEmbed
54-
55-
m = Florence2VisionConvEmbed(
56-
patch_size=7,
57-
in_channels=3,
58-
embed_dim=64,
59-
stride=4,
60-
padding=3,
61-
pre_norm=pre_norm,
62-
)
63-
out = m(torch.randn(1, 3, 64, 64))
64-
assert out.shape[1] == 64
65-
66-
67-
class TestFlorenceVisionWindowAttention:
68-
def test_exact_window(self):
69-
from vllm_bart_plugin.florence2 import Florence2VisionWindowAttention
70-
71-
m = Florence2VisionWindowAttention(dim=32, num_heads=4, window_size=4)
72-
assert m(torch.randn(1, 4, 4, 32)).shape == (1, 16, 32)
73-
74-
def test_input_requires_padding(self):
75-
from vllm_bart_plugin.florence2 import Florence2VisionWindowAttention
76-
77-
m = Florence2VisionWindowAttention(dim=32, num_heads=4, window_size=4)
78-
# 6 is not divisible by 4; output should still be (B, 6*6, C)
79-
assert m(torch.randn(1, 6, 6, 32)).shape == (1, 36, 32)
80-
81-
8233
class TestFlorenceVisionBackbone:
8334
def test_output_shape(self):
8435
from vllm_bart_plugin.florence2 import Florence2VisionBackbone
@@ -184,9 +135,9 @@ def test_caption(self, florence2_llm, florence2_processor, stop_sign_image):
184135
max_tokens=30,
185136
)
186137
text = result["<CAPTION>"].lower()
187-
assert "car" in text or "stop" in text, (
188-
f"<CAPTION> output missing expected content: {text!r}"
189-
)
138+
assert (
139+
"car" in text or "stop" in text
140+
), f"<CAPTION> output missing expected content: {text!r}"
190141

191142
def test_detailed_caption(
192143
self, florence2_llm, florence2_processor, stop_sign_image
@@ -215,9 +166,9 @@ def test_more_detailed_caption(
215166
max_tokens=100,
216167
)
217168
text = result["<MORE_DETAILED_CAPTION>"].lower()
218-
assert "stop sign" in text or "sign" in text, (
219-
f"<MORE_DETAILED_CAPTION> missing 'stop sign': {text!r}"
220-
)
169+
assert (
170+
"stop sign" in text or "sign" in text
171+
), f"<MORE_DETAILED_CAPTION> missing 'stop sign': {text!r}"
221172
assert len(text.split()) >= 10, f"<MORE_DETAILED_CAPTION> too short: {text!r}"
222173

223174
# ------------------------------------------------------------------
@@ -237,12 +188,12 @@ def test_object_detection(
237188
for bbox in od["bboxes"]:
238189
assert len(bbox) == 4 and all(c >= 0 for c in bbox)
239190
labels = od["labels"]
240-
assert "stop sign" in labels, (
241-
f"Expected 'stop sign' in OD labels, got: {labels}"
242-
)
243-
assert "car" in labels or "building" in labels, (
244-
f"Expected common objects in OD labels, got: {labels}"
245-
)
191+
assert (
192+
"stop sign" in labels
193+
), f"Expected 'stop sign' in OD labels, got: {labels}"
194+
assert (
195+
"car" in labels or "building" in labels
196+
), f"Expected common objects in OD labels, got: {labels}"
246197

247198
def test_dense_region_caption(
248199
self, florence2_llm, florence2_processor, stop_sign_image
@@ -257,9 +208,9 @@ def test_dense_region_caption(
257208
drc = result["<DENSE_REGION_CAPTION>"]
258209
assert "bboxes" in drc and "labels" in drc
259210
assert len(drc["bboxes"]) == len(drc["labels"]) > 0
260-
assert "stop sign" in drc["labels"], (
261-
f"Expected 'stop sign' in dense captions, got: {drc['labels']}"
262-
)
211+
assert (
212+
"stop sign" in drc["labels"]
213+
), f"Expected 'stop sign' in dense captions, got: {drc['labels']}"
263214

264215
def test_region_proposal(self, florence2_llm, florence2_processor, stop_sign_image):
265216
result = _run_task(
@@ -291,9 +242,9 @@ def test_ocr_with_region(self, florence2_llm, florence2_processor, stop_sign_ima
291242
assert len(quad) == 8
292243
# "STOP" is the most prominent text in the image
293244
joined = " ".join(ocr["labels"])
294-
assert "STOP" in joined, (
295-
f"Expected 'STOP' in OCR_WITH_REGION labels, got: {joined!r}"
296-
)
245+
assert (
246+
"STOP" in joined
247+
), f"Expected 'STOP' in OCR_WITH_REGION labels, got: {joined!r}"
297248

298249
def test_caption_to_phrase_grounding(
299250
self, florence2_llm, florence2_processor, stop_sign_image
@@ -309,9 +260,9 @@ def test_caption_to_phrase_grounding(
309260
cpg = result["<CAPTION_TO_PHRASE_GROUNDING>"]
310261
assert "bboxes" in cpg and "labels" in cpg
311262
assert len(cpg["bboxes"]) > 0
312-
assert any("stop sign" in lbl.lower() for lbl in cpg["labels"]), (
313-
f"Expected 'stop sign' grounded, got labels: {cpg['labels']}"
314-
)
263+
assert any(
264+
"stop sign" in lbl.lower() for lbl in cpg["labels"]
265+
), f"Expected 'stop sign' grounded, got labels: {cpg['labels']}"
315266

316267
# ------------------------------------------------------------------
317268
# Batch tests

0 commit comments

Comments (0)