Skip to content

Commit 7cdec11

Browse files
committed
add tests
Signed-off-by: Carles Onielfa <carlesonielfa@gmail.com>
1 parent b42b288 commit 7cdec11

3 files changed

Lines changed: 501 additions & 141 deletions

File tree

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,20 @@ include = '\.pyi?$'
6262
profile = "black"
6363
line_length = 88
6464

65+
[tool.pytest.ini_options]
66+
markers = [
67+
"slow: marks tests requiring a GPU and full model download (deselect with '-m \"not slow\"')",
68+
]
69+
6570
[tool.mypy]
6671
python_version = "3.10"
6772
warn_return_any = true
6873
warn_unused_configs = true
6974
ignore_missing_imports = true
75+
76+
[dependency-groups]
77+
dev = [
78+
"black>=26.1.0",
79+
"isort>=8.0.1",
80+
"pytest>=9.0.2",
81+
]

tests/test_florence2.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""Tests for the Florence-2 multimodal model plugin."""
2+
3+
import os
4+
5+
import pytest
6+
import torch
7+
from transformers import Florence2Config
8+
9+
MODEL_NAME = "florence-community/Florence-2-base-ft"
10+
11+
12+
def _small_vision_config():
13+
"""Tiny 1-stage Florence2 config for fast CPU tests."""
14+
cfg = Florence2Config()
15+
vc = cfg.vision_config
16+
vc.embed_dim = [64]
17+
vc.depths = [1]
18+
vc.num_heads = [4]
19+
vc.num_groups = [4]
20+
vc.patch_size = [7]
21+
vc.patch_stride = [4]
22+
vc.patch_padding = [3]
23+
vc.patch_prenorm = [False]
24+
vc.drop_path_rate = 0.0
25+
return cfg, vc
26+
27+
28+
# ---------------------------------------------------------------------------
29+
# Unit tests — vision architecture (CPU, no weights)
30+
# ---------------------------------------------------------------------------
31+
32+
33+
class TestFlorenceVisionDropPath:
34+
def test_eval_is_identity(self):
35+
from vllm_bart_plugin.florence2 import Florence2VisionDropPath
36+
37+
m = Florence2VisionDropPath(drop_prob=0.9).eval()
38+
x = torch.randn(2, 16)
39+
assert torch.equal(m(x), x)
40+
41+
def test_training_drops_samples(self):
42+
from vllm_bart_plugin.florence2 import Florence2VisionDropPath
43+
44+
torch.manual_seed(0)
45+
m = Florence2VisionDropPath(drop_prob=0.5).train()
46+
out = m(torch.ones(64, 16))
47+
assert not torch.all(out == 1)
48+
49+
50+
class TestFlorenceVisionConvEmbed:
51+
@pytest.mark.parametrize("pre_norm", [True, False])
52+
def test_output_channels(self, pre_norm):
53+
from vllm_bart_plugin.florence2 import Florence2VisionConvEmbed
54+
55+
m = Florence2VisionConvEmbed(
56+
patch_size=7,
57+
in_channels=3,
58+
embed_dim=64,
59+
stride=4,
60+
padding=3,
61+
pre_norm=pre_norm,
62+
)
63+
out = m(torch.randn(1, 3, 64, 64))
64+
assert out.shape[1] == 64
65+
66+
67+
class TestFlorenceVisionWindowAttention:
68+
def test_exact_window(self):
69+
from vllm_bart_plugin.florence2 import Florence2VisionWindowAttention
70+
71+
m = Florence2VisionWindowAttention(dim=32, num_heads=4, window_size=4)
72+
assert m(torch.randn(1, 4, 4, 32)).shape == (1, 16, 32)
73+
74+
def test_input_requires_padding(self):
75+
from vllm_bart_plugin.florence2 import Florence2VisionWindowAttention
76+
77+
m = Florence2VisionWindowAttention(dim=32, num_heads=4, window_size=4)
78+
# 6 is not divisible by 4; output should still be (B, 6*6, C)
79+
assert m(torch.randn(1, 6, 6, 32)).shape == (1, 36, 32)
80+
81+
82+
class TestFlorenceVisionBackbone:
83+
def test_output_shape(self):
84+
from vllm_bart_plugin.florence2 import Florence2VisionBackbone
85+
86+
_, vc = _small_vision_config()
87+
out = Florence2VisionBackbone(vc)(torch.randn(2, 3, 64, 64))
88+
assert out.shape == (2, vc.embed_dim[-1], 16, 16)
89+
90+
91+
class TestFlorenceVisionPositionalEmbeddingCosine1D:
92+
def test_output_shape_and_no_batch_dim(self):
93+
from vllm_bart_plugin.florence2 import (
94+
Florence2VisionPositionalEmbeddingCosine1D,
95+
)
96+
97+
m = Florence2VisionPositionalEmbeddingCosine1D(embed_dim=64, max_seq_len=100)
98+
assert m(torch.randn(2, 5, 64)).shape == (5, 64)
99+
100+
def test_raises_if_exceeds_max(self):
101+
from vllm_bart_plugin.florence2 import (
102+
Florence2VisionPositionalEmbeddingCosine1D,
103+
)
104+
105+
m = Florence2VisionPositionalEmbeddingCosine1D(embed_dim=64, max_seq_len=10)
106+
with pytest.raises(AssertionError):
107+
m(torch.randn(1, 20, 64))
108+
109+
110+
class TestFlorenceMultiModalProjector:
111+
def test_output_shape(self):
112+
from vllm_bart_plugin.florence2 import Florence2MultiModalProjector
113+
114+
cfg, vc = _small_vision_config()
115+
vc.projection_dim = 128
116+
m = Florence2MultiModalProjector(cfg)
117+
out = m(torch.randn(2, vc.embed_dim[-1], 12, 12))
118+
# (B, 1 spatial-avg token + H*W tokens, proj_dim)
119+
assert out.shape == (2, 1 + 12 * 12, vc.projection_dim)
120+
121+
122+
# ---------------------------------------------------------------------------
123+
# Integration tests — full model inference (GPU required)
124+
# ---------------------------------------------------------------------------
125+
126+
127+
@pytest.fixture(scope="module")
128+
def florence2_llm():
129+
from vllm import LLM
130+
131+
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
132+
return LLM(
133+
model=MODEL_NAME,
134+
trust_remote_code=True,
135+
enforce_eager=True,
136+
gpu_memory_utilization=0.5,
137+
mm_processor_cache_gb=0,
138+
)
139+
140+
141+
@pytest.fixture(scope="module")
142+
def stop_sign_image():
143+
from vllm.assets.image import ImageAsset
144+
145+
return ImageAsset("stop_sign").pil_image
146+
147+
148+
@pytest.fixture(scope="module")
149+
def sampling_params():
150+
from vllm import SamplingParams
151+
152+
return SamplingParams(
153+
temperature=0.0,
154+
max_tokens=20,
155+
repetition_penalty=1.5,
156+
skip_special_tokens=False,
157+
)
158+
159+
160+
@pytest.mark.slow
161+
class TestFlorenceInference:
162+
def test_caption(self, florence2_llm, stop_sign_image, sampling_params):
163+
outputs = florence2_llm.generate(
164+
[
165+
{
166+
"prompt": "<DETAILED_CAPTION>",
167+
"multi_modal_data": {"image": stop_sign_image},
168+
}
169+
],
170+
sampling_params=sampling_params,
171+
)
172+
assert len(outputs[0].outputs[0].text) > 0
173+
174+
def test_object_detection_has_loc_tokens(
175+
self, florence2_llm, stop_sign_image, sampling_params
176+
):
177+
outputs = florence2_llm.generate(
178+
[
179+
{
180+
"encoder_prompt": {
181+
"prompt": "<OD>",
182+
"multi_modal_data": {"image": stop_sign_image},
183+
},
184+
"decoder_prompt": "",
185+
}
186+
],
187+
sampling_params=sampling_params,
188+
)
189+
assert "<loc_" in outputs[0].outputs[0].text
190+
191+
def test_batch_inference(self, florence2_llm, stop_sign_image, sampling_params):
192+
prompts = [
193+
{"prompt": "<CAPTION>", "multi_modal_data": {"image": stop_sign_image}},
194+
{
195+
"prompt": "<DETAILED_CAPTION>",
196+
"multi_modal_data": {"image": stop_sign_image},
197+
},
198+
]
199+
outputs = florence2_llm.generate(prompts, sampling_params=sampling_params)
200+
assert all(len(o.outputs[0].text) > 0 for o in outputs)
201+
202+
def test_encoder_length_within_limit(self, stop_sign_image):
203+
"""Processor output must not exceed BART max_position_embeddings."""
204+
from transformers import AutoProcessor
205+
206+
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
207+
out = processor(
208+
text="<DETAILED_CAPTION>", images=stop_sign_image, return_tensors="pt"
209+
)
210+
assert out["input_ids"].shape[1] <= 1024

0 commit comments

Comments
 (0)