diff --git a/README.md b/README.md
index 60efa88..6e0c11a 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,13 @@ This plugin requires [uv](https://docs.astral.sh/uv/) for package management. If
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
+### From Git
+
+Install from git:
+
+```bash
+pip install git+https://github.com/vllm-project/bart-plugin
+```
### From Source
@@ -126,10 +133,8 @@ This plugin should work with any BART-based model from HuggingFace, including:
### Florence-2 Models
-- `microsoft/Florence-2-base`
-- `microsoft/Florence-2-large`
-
-Note: Florence-2 requires `trust_remote_code=True` and uses a separate tokenizer (`Isotr0py/Florence-2-tokenizer`).
+- `florence-community/Florence-2-base`
+- `florence-community/Florence-2-large`
## Evaluation
@@ -186,11 +191,14 @@ Notes:
```
bart-plugin/
├── vllm_bart_plugin/
-│ ├── __init__.py # Plugin registration
-│ └── bart.py # BART model implementation
-├── setup.py # Package configuration and entry points
-├── README.md # This file
-└── LICENSE # License file
+│ ├── __init__.py # Plugin registration
+│ └── bart.py # BART model implementation
+│ └── florence2.py # Florence-2 model implementation
+├── setup.py # Package configuration and entry points
+├── README.md # This file
+└── LICENSE # License file
+└── example_bart_usage.py # Example usage script for BART
+└── example_florence2_usage.py # Example usage script for Florence-2
```
### Running Tests
diff --git a/example_florence2_usage.py b/example_florence2_usage.py
index 6e9593b..32c9c96 100644
--- a/example_florence2_usage.py
+++ b/example_florence2_usage.py
@@ -5,28 +5,23 @@
This script demonstrates how to use Florence-2 models with vLLM
after installing the BART plugin.
"""
-import vllm_bart_plugin
+
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
def main():
"""Run Florence-2 model examples."""
- model_name = "microsoft/Florence-2-large"
- tokenizer_name = "Isotr0py/Florence-2-tokenizer"
+ model_name = "florence-community/Florence-2-large-ft"
llm = LLM(
model=model_name,
- tokenizer=tokenizer_name,
mm_processor_cache_gb=0,
- trust_remote_code=True,
enforce_eager=True,
)
params = SamplingParams(
temperature=0.0,
max_tokens=20,
- # repetition_penalty is needed to prevent repetition
- repetition_penalty=1.5,
# skip_special_tokens=False is needed to present
# grounding tokens like
skip_special_tokens=False,
@@ -60,4 +55,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index c2627f2..e1789ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vllm-bart-plugin"
-version = "0.2.0"
+version = "0.3.0"
description = "BART model plugin for vLLM"
readme = "README.md"
requires-python = ">=3.10"
@@ -26,9 +26,9 @@ classifiers = [
]
dependencies = [
- "vllm>=0.14.0",
+ "vllm>=0.16.0",
"torch>=2.9.0",
- "transformers>=4.56.0,<5",
+ "transformers>=4.56.0,<6",
]
[project.optional-dependencies]
@@ -62,8 +62,20 @@ include = '\.pyi?$'
profile = "black"
line_length = 88
+[tool.pytest.ini_options]
+markers = [
+ "slow: marks tests requiring a GPU and full model download (deselect with '-m \"not slow\"')",
+]
+
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
+
+[dependency-groups]
+dev = [
+ "black>=26.1.0",
+ "isort>=8.0.1",
+ "pytest>=9.0.2",
+]
diff --git a/tests/test_florence2.py b/tests/test_florence2.py
new file mode 100644
index 0000000..38e5731
--- /dev/null
+++ b/tests/test_florence2.py
@@ -0,0 +1,267 @@
+"""Tests for the Florence-2 multimodal model plugin."""
+
+import os
+
+import pytest
+import torch
+from transformers import Florence2Config
+
+MODEL_NAME = "florence-community/Florence-2-base-ft"
+
+
+def _small_vision_config():
+ """Tiny 1-stage Florence2 config for fast CPU tests."""
+ cfg = Florence2Config()
+ vc = cfg.vision_config
+ vc.embed_dim = [64]
+ vc.depths = [1]
+ vc.num_heads = [4]
+ vc.num_groups = [4]
+ vc.patch_size = [7]
+ vc.patch_stride = [4]
+ vc.patch_padding = [3]
+ vc.patch_prenorm = [False]
+ vc.drop_path_rate = 0.0
+ return cfg, vc
+
+
+# ---------------------------------------------------------------------------
+# Unit tests — vision architecture (CPU, no weights)
+# ---------------------------------------------------------------------------
+
+
+class TestFlorenceVisionBackbone:
+ def test_output_shape(self):
+ from vllm_bart_plugin.florence2 import Florence2VisionBackbone
+
+ _, vc = _small_vision_config()
+ out = Florence2VisionBackbone(vc)(torch.randn(2, 3, 64, 64))
+ assert out.shape == (2, vc.embed_dim[-1], 16, 16)
+
+
+class TestFlorenceMultiModalProjector:
+ def test_output_shape(self):
+ from vllm_bart_plugin.florence2 import Florence2MultiModalProjector
+
+ cfg, vc = _small_vision_config()
+ vc.projection_dim = 128
+ m = Florence2MultiModalProjector(cfg)
+ out = m(torch.randn(2, vc.embed_dim[-1], 12, 12))
+ # (B, 1 spatial-avg token + H*W tokens, proj_dim)
+ assert out.shape == (2, 1 + 12 * 12, vc.projection_dim)
+
+
+# ---------------------------------------------------------------------------
+# Integration tests — full model inference (GPU required)
+# ---------------------------------------------------------------------------
+
+
+def _run_task(llm, processor, image, task_prompt, text_input=None, max_tokens=100):
+ """Helper: run one Florence-2 task and return the post-processed result."""
+ from vllm import SamplingParams
+
+ prompt = task_prompt if text_input is None else task_prompt + text_input
+ params = SamplingParams(
+ temperature=0.0, max_tokens=max_tokens, skip_special_tokens=False
+ )
+ outputs = llm.generate(
+ [{"prompt": prompt, "multi_modal_data": {"image": image}}],
+ sampling_params=params,
+ )
+ raw = outputs[0].outputs[0].text
+ return processor.post_process_generation(
+ raw, task=task_prompt, image_size=image.size
+ )
+
+
+@pytest.fixture(scope="module")
+def florence2_llm():
+ from vllm import LLM
+
+ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
+ return LLM(
+ model=MODEL_NAME,
+ enforce_eager=True,
+ gpu_memory_utilization=0.5,
+ mm_processor_cache_gb=0,
+ )
+
+
+@pytest.fixture(scope="module")
+def florence2_processor():
+ from transformers import AutoProcessor
+
+ return AutoProcessor.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def stop_sign_image():
+ from vllm.assets.image import ImageAsset
+
+ return ImageAsset("stop_sign").pil_image.convert("RGB")
+
+
+@pytest.mark.slow
+class TestFlorenceInference:
+ # ------------------------------------------------------------------
+ # Caption tasks — check for semantically meaningful keywords
+ # ------------------------------------------------------------------
+
+ def test_caption(self, florence2_llm, florence2_processor, stop_sign_image):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ max_tokens=30,
+ )
+ text = result[""].lower()
+ assert (
+ "car" in text or "stop" in text
+ ), f" output missing expected content: {text!r}"
+
+ def test_detailed_caption(
+ self, florence2_llm, florence2_processor, stop_sign_image
+ ):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ max_tokens=80,
+ )
+ text = result[""].lower()
+ # Must mention the car and give some background detail — guards against the
+ # KV-cache encoder_seq_lens regression that previously produced garbled output.
+ assert "car" in text, f" missing 'car': {text!r}"
+ assert len(text.split()) >= 10, f" too short: {text!r}"
+
+ def test_more_detailed_caption(
+ self, florence2_llm, florence2_processor, stop_sign_image
+ ):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ max_tokens=100,
+ )
+ text = result[""].lower()
+ assert (
+ "stop sign" in text or "sign" in text
+ ), f" missing 'stop sign': {text!r}"
+ assert len(text.split()) >= 10, f" too short: {text!r}"
+
+ # ------------------------------------------------------------------
+ # Structured-output tasks — check schema and key labels
+ # ------------------------------------------------------------------
+
+ def test_object_detection(
+ self, florence2_llm, florence2_processor, stop_sign_image
+ ):
+ result = _run_task(
+ florence2_llm, florence2_processor, stop_sign_image, "", max_tokens=300
+ )
+ od = result[""]
+ assert "bboxes" in od and "labels" in od
+ assert len(od["bboxes"]) == len(od["labels"]) > 0
+ # Each bbox must be a 4-element list with non-negative coords
+ for bbox in od["bboxes"]:
+ assert len(bbox) == 4 and all(c >= 0 for c in bbox)
+ labels = od["labels"]
+ assert (
+ "stop sign" in labels
+ ), f"Expected 'stop sign' in OD labels, got: {labels}"
+ assert (
+ "car" in labels or "building" in labels
+ ), f"Expected common objects in OD labels, got: {labels}"
+
+ def test_dense_region_caption(
+ self, florence2_llm, florence2_processor, stop_sign_image
+ ):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ max_tokens=250,
+ )
+ drc = result[""]
+ assert "bboxes" in drc and "labels" in drc
+ assert len(drc["bboxes"]) == len(drc["labels"]) > 0
+ assert (
+ "stop sign" in drc["labels"]
+ ), f"Expected 'stop sign' in dense captions, got: {drc['labels']}"
+
+ def test_region_proposal(self, florence2_llm, florence2_processor, stop_sign_image):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ max_tokens=100,
+ )
+ rp = result[""]
+ assert "bboxes" in rp and "labels" in rp
+ assert len(rp["bboxes"]) > 0
+ # Region proposal labels are always empty strings
+ assert all(label == "" for label in rp["labels"])
+
+ def test_ocr_with_region(self, florence2_llm, florence2_processor, stop_sign_image):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ max_tokens=250,
+ )
+ ocr = result[""]
+ assert "quad_boxes" in ocr and "labels" in ocr
+ assert len(ocr["quad_boxes"]) == len(ocr["labels"]) > 0
+ # Each quad box must be 8 coords
+ for quad in ocr["quad_boxes"]:
+ assert len(quad) == 8
+ # "STOP" is the most prominent text in the image
+ joined = " ".join(ocr["labels"])
+ assert (
+ "STOP" in joined
+ ), f"Expected 'STOP' in OCR_WITH_REGION labels, got: {joined!r}"
+
+ def test_caption_to_phrase_grounding(
+ self, florence2_llm, florence2_processor, stop_sign_image
+ ):
+ result = _run_task(
+ florence2_llm,
+ florence2_processor,
+ stop_sign_image,
+ "",
+ text_input="A stop sign on a street corner.",
+ max_tokens=80,
+ )
+ cpg = result[""]
+ assert "bboxes" in cpg and "labels" in cpg
+ assert len(cpg["bboxes"]) > 0
+ assert any(
+ "stop sign" in lbl.lower() for lbl in cpg["labels"]
+ ), f"Expected 'stop sign' grounded, got labels: {cpg['labels']}"
+
+ # ------------------------------------------------------------------
+ # Batch tests
+ # ------------------------------------------------------------------
+
+ def test_batch_inference(self, florence2_llm, florence2_processor, stop_sign_image):
+ """Multiple prompts in one batch must all produce non-empty output."""
+ from vllm import SamplingParams
+
+ params = SamplingParams(
+ temperature=0.0, max_tokens=30, skip_special_tokens=False
+ )
+ prompts = [
+ {"prompt": "", "multi_modal_data": {"image": stop_sign_image}},
+ {
+ "prompt": "",
+ "multi_modal_data": {"image": stop_sign_image},
+ },
+ ]
+ outputs = florence2_llm.generate(prompts, sampling_params=params)
+ assert all(len(o.outputs[0].text) > 0 for o in outputs)
diff --git a/tests/test_model_initialization.py b/tests/test_model_initialization.py
index 2a44aec..7f01bec 100644
--- a/tests/test_model_initialization.py
+++ b/tests/test_model_initialization.py
@@ -1,7 +1,8 @@
"""Tests for BART model initialization."""
import pytest
-import torch
+import vllm_bart_plugin
+vllm_bart_plugin.register_bart_model()
from vllm import LLM
@@ -42,6 +43,7 @@ def test_model_with_custom_config(self, small_model_name):
except Exception as e:
pytest.fail(f"Failed to load model with config: {e}")
+ @pytest.mark.slow
def test_model_class_initialization(self):
"""Test that model class can be instantiated."""
from vllm_bart_plugin.bart import BartForConditionalGeneration
@@ -53,7 +55,6 @@ def test_model_class_initialization(self):
model_config = ModelConfig(
model="facebook/bart-large-cnn",
- task="generate",
tokenizer="facebook/bart-large-cnn",
tokenizer_mode="auto",
trust_remote_code=False,
@@ -65,7 +66,6 @@ def test_model_class_initialization(self):
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.3,
- swap_space_bytes=0,
cache_dtype="auto",
)
@@ -77,7 +77,9 @@ def test_model_class_initialization(self):
# Try to instantiate the model
try:
- model = BartForConditionalGeneration(vllm_config=vllm_config)
+ from vllm.config import set_current_vllm_config
+ with set_current_vllm_config(vllm_config):
+ model = BartForConditionalGeneration(vllm_config=vllm_config)
assert model is not None
assert hasattr(model, 'model')
assert hasattr(model, 'lm_head')
@@ -92,13 +94,14 @@ def test_model_has_required_methods(self):
'forward',
'compute_logits',
'load_weights',
- 'get_multimodal_embeddings',
+ 'embed_multimodal',
]
for method in required_methods:
assert hasattr(BartForConditionalGeneration, method), \
f"Model missing required method: {method}"
+ @pytest.mark.slow
def test_encoder_decoder_structure(self):
"""Test that BART has proper encoder-decoder structure."""
from vllm_bart_plugin.bart import BartModel, BartEncoder, BartDecoder
@@ -109,7 +112,6 @@ def test_encoder_decoder_structure(self):
model_config = ModelConfig(
model="facebook/bart-large-cnn",
- task="generate",
tokenizer="facebook/bart-large-cnn",
tokenizer_mode="auto",
trust_remote_code=False,
@@ -121,7 +123,6 @@ def test_encoder_decoder_structure(self):
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.3,
- swap_space_bytes=0,
cache_dtype="auto",
)
@@ -131,7 +132,9 @@ def test_encoder_decoder_structure(self):
load_config=LoadConfig(),
)
- model = BartModel(vllm_config=vllm_config)
+ from vllm.config import set_current_vllm_config
+ with set_current_vllm_config(vllm_config):
+ model = BartModel(vllm_config=vllm_config)
assert hasattr(model, 'encoder')
assert hasattr(model, 'decoder')
diff --git a/vllm_bart_plugin/__init__.py b/vllm_bart_plugin/__init__.py
index 0648366..df079d7 100644
--- a/vllm_bart_plugin/__init__.py
+++ b/vllm_bart_plugin/__init__.py
@@ -34,10 +34,10 @@ def register_bart_model() -> None:
"vllm_bart_plugin.florence2:Florence2ForConditionalGeneration",
)
- logger.info("Successfully registered BART model with vLLM")
+ logger.info("Successfully registered BART and Florence2 models with vLLM")
except Exception as e:
- logger.error(f"Failed to register BART model: {e}")
+ logger.error(f"Failed to register BART and Florence2 models: {e}")
raise
diff --git a/vllm_bart_plugin/bart.py b/vllm_bart_plugin/bart.py
index f3d6cd5..a21e3fc 100644
--- a/vllm_bart_plugin/bart.py
+++ b/vllm_bart_plugin/bart.py
@@ -29,7 +29,8 @@
from torch import nn
from transformers import BartConfig
from transformers.utils import logging
-from vllm.attention.layer import Attention, AttentionType
+from vllm.model_executor.layers.attention import Attention
+from vllm.v1.attention.backend import AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -78,7 +79,7 @@
EncDecMultiModalProcessor,
PromptUpdate,
)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.collection_utils import is_list_of
@@ -927,6 +928,9 @@ def get_mm_max_tokens_per_item(
config = self.get_hf_config()
return {"text": config.max_position_embeddings}
+ def get_data_parser(self) -> "MultiModalDataParser":
+ return TextDataParser()
+
class BartDummyInputsBuilder(BaseDummyInputsBuilder[BartProcessingInfo]):
"""Builds dummy inputs for profiling BART models."""
@@ -1107,7 +1111,7 @@ def _get_prompt_updates(
)
]
- def _get_data_parser(self) -> MultiModalDataParser:
+ def build_data_parser(self) -> MultiModalDataParser:
return TextDataParser()
@@ -1152,8 +1156,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.vocab_size, config.d_model, embed_scale=embed_scale
)
# Bias added to logits after lm_head, matching HuggingFace approach
- self.register_buffer("final_logits_bias",
- torch.zeros((1, config.vocab_size)))
+ self.register_buffer("final_logits_bias", torch.zeros((1, config.vocab_size)))
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
diff --git a/vllm_bart_plugin/florence2.py b/vllm_bart_plugin/florence2.py
index 3624b1f..b215dc1 100644
--- a/vllm_bart_plugin/florence2.py
+++ b/vllm_bart_plugin/florence2.py
@@ -1,66 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Literal, TypedDict, OrderedDict
import math
from collections.abc import Iterable, Mapping, Sequence
-from typing import Any
-
-import torch.nn.functional as F
-from einops import rearrange
+from typing import Literal, TypedDict
import torch
from torch import nn
-from transformers import BartConfig, BatchFeature, BartTokenizer, PretrainedConfig
-from transformers.utils import logging
-
-from vllm.attention.layer import Attention, AttentionType
-from vllm.model_executor.layers.attention.cross_attention import CrossAttention
-from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
-from vllm.config import CacheConfig, VllmConfig
-from vllm.config.lora import LoRAConfig
-from vllm.config.multimodal import BaseDummyOptions
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
+from transformers import (
+ BartTokenizer,
+ BatchFeature,
+ Florence2Config,
+ Florence2Processor,
)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
- ParallelLMHead,
- VocabParallelEmbedding,
+from transformers.models.florence2.modeling_florence2 import (
+ Florence2MultiModalProjector,
+ Florence2VisionBackbone,
)
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MULTIMODAL_REGISTRY, ModalityData
+from vllm.model_executor.models.interfaces import (
+ MultiModalEmbeddings,
+ SupportsMultiModal,
+)
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
-from vllm.multimodal.parse import (
- ModalityDataItems,
- ModalityDataParser,
- MultiModalDataItems,
- MultiModalDataParser,
- ProcessorBatchItems,
-)
+from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
BaseProcessingInfo,
EncDecMultiModalProcessor,
- PromptUpdate,
- PromptInsertion,
PromptIndexTargets,
+ PromptInsertion,
+ PromptUpdate,
)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
-from vllm.utils.collection_utils import is_list_of
-
-from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
-from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, maybe_prefix, flatten_bn
-from vllm_bart_plugin.bart import BartDecoder, BartEncoder, BartParallelLMHead, BartScaledWordEmbedding
+from vllm_bart_plugin.bart import (
+ BartDecoder,
+ BartEncoder,
+ BartParallelLMHead,
+ BartScaledWordEmbedding,
+)
class Florence2ImagePixelInputs(TypedDict):
@@ -69,547 +59,6 @@ class Florence2ImagePixelInputs(TypedDict):
"""Shape: (batch_size, num_channel, height, width)"""
-# ViT implementation are all copied from
-# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
-class LearnedAbsolutePositionEmbedding2D(nn.Module):
- """
- This module learns positional embeddings up to a fixed maximum size.
- """
-
- def __init__(self, embedding_dim=256, num_pos=50):
- super().__init__()
- self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
- self.column_embeddings = nn.Embedding(
- num_pos, embedding_dim - (embedding_dim // 2))
-
- def forward(self, pixel_values):
- """
- pixel_values: (batch_size, height, width, num_channels)
- returns: (batch_size, height, width, embedding_dim * 2)
- """
- if len(pixel_values.shape) != 4:
- raise ValueError('pixel_values must be a 4D tensor')
- height, width = pixel_values.shape[1:3]
- width_values = torch.arange(width, device=pixel_values.device)
- height_values = torch.arange(height, device=pixel_values.device)
- x_emb = self.column_embeddings(width_values)
- y_emb = self.row_embeddings(height_values)
- # (height, width, embedding_dim * 2)
- pos = torch.cat([
- x_emb.unsqueeze(0).repeat(height, 1, 1),
- y_emb.unsqueeze(1).repeat(1, width, 1)
- ],
- dim=-1)
- # (embedding_dim * 2, height, width)
- pos = pos.permute(2, 0, 1)
- pos = pos.unsqueeze(0)
- # (batch_size, embedding_dim * 2, height, width)
- pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
- # (batch_size, height, width, embedding_dim * 2)
- pos = pos.permute(0, 2, 3, 1)
- return pos
-
-
-class PositionalEmbeddingCosine1D(nn.Module):
- """
- This class implements a very simple positional encoding. It follows closely
- the encoder from the link below:
- https://pytorch.org/tutorials/beginner/translation_transformer.html
- Args:
- embed_dim: The dimension of the embeddings.
- dropout_prob: The dropout probability.
- max_seq_len: The maximum length to precompute the positional encodings.
- """
-
- def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
- super().__init__()
- self.embed_dim = embed_dim
- self.max_seq_len = max_seq_len
- # Generate the sinusoidal arrays.
- factor = math.log(10000)
- denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
- self.embed_dim)
- # Matrix where rows correspond to a positional embedding as a function
- # of the position index (i.e., the row index).
- frequencies = \
- torch.arange(0, self.max_seq_len) \
- .reshape(self.max_seq_len, 1) * denominator
- pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
- # Populate uneven entries.
- pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
- pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
- # Save the positional embeddings in a constant buffer.
- # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
- self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed,
- requires_grad=False)
-
- def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
- """
- Args:
- seq_embeds: The sequence embeddings in order. Allowed size:
- 1. [T, D], where T is the length of the sequence, and D is the
- frame embedding dimension.
- 2. [B, T, D], where B is the batch size and T and D are the
- same as above.
- Returns a tensor of with the same dimensions as the input: i.e.,
- [1, T, D] or [T, D].
- """
- shape_len = len(seq_embeds.shape)
- assert 2 <= shape_len <= 3
- len_seq = seq_embeds.size(-2)
- assert len_seq <= self.max_seq_len
- pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
- # Adapt pre-computed positional embeddings to the input.
- if shape_len == 3:
- pos_embeds = pos_embeds.view(
- (1, pos_embeds.size(0), pos_embeds.size(1)))
- return pos_embeds
-
-
-class MySequential(nn.Sequential):
-
- def forward(self, *inputs):
- for module in self._modules.values():
- if isinstance(inputs, tuple):
- inputs = module(*inputs)
- else:
- inputs = module(inputs)
- return inputs
-
-
-class PreNorm(nn.Module):
-
- def __init__(self, norm, fn):
- super().__init__()
- self.norm = norm
- self.fn = fn
-
- def forward(self, x, *args, **kwargs):
- shortcut = x
- if self.norm is not None:
- x, size = self.fn(self.norm(x), *args, **kwargs)
- else:
- x, size = self.fn(x, *args, **kwargs)
-
- x = shortcut + x
-
- return x, size
-
-
-class Mlp(nn.Module):
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.net = nn.Sequential(
- OrderedDict([("fc1", nn.Linear(in_features, hidden_features)),
- ("act", act_layer()),
- ("fc2", nn.Linear(hidden_features, out_features))]))
-
- def forward(self, x, size):
- return self.net(x), size
-
-
-class DepthWiseConv2d(nn.Module):
-
- def __init__(
- self,
- dim_in,
- kernel_size,
- padding,
- stride,
- bias=True,
- ):
- super().__init__()
- self.dw = nn.Conv2d(dim_in,
- dim_in,
- kernel_size=kernel_size,
- padding=padding,
- groups=dim_in,
- stride=stride,
- bias=bias)
-
- def forward(self, x, size):
- B, N, C = x.shape
- H, W = size
- assert N == H * W
-
- x = self.dw(x.transpose(1, 2).view(B, C, H, W))
- size = (x.size(-2), x.size(-1))
- x = x.flatten(2).transpose(1, 2)
- return x, size
-
-
-class ConvEmbed(nn.Module):
- """ Image to Patch Embedding
- """
-
- def __init__(self,
- patch_size=7,
- in_chans=3,
- embed_dim=64,
- stride=4,
- padding=2,
- norm_layer=None,
- pre_norm=True):
- super().__init__()
- self.patch_size = patch_size
-
- self.proj = nn.Conv2d(in_chans,
- embed_dim,
- kernel_size=patch_size,
- stride=stride,
- padding=padding)
-
- dim_norm = in_chans if pre_norm else embed_dim
- self.norm = norm_layer(dim_norm) if norm_layer else None
-
- self.pre_norm = pre_norm
-
- def forward(self, x, size):
- H, W = size
- if len(x.size()) == 3:
- if self.norm and self.pre_norm:
- x = self.norm(x)
- x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
-
- x = self.proj(x)
-
- _, _, H, W = x.shape
- x = rearrange(x, 'b c h w -> b (h w) c')
- if self.norm and not self.pre_norm:
- x = self.norm(x)
-
- return x, (H, W)
-
-
-class ChannelAttention(nn.Module):
-
- def __init__(self, dim, groups=8, qkv_bias=True):
- super().__init__()
-
- self.groups = groups
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.proj = nn.Linear(dim, dim)
-
- def forward(self, x, size):
- B, N, C = x.shape
-
- qkv = self.qkv(x).reshape(B, N, 3, self.groups,
- C // self.groups).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- q = q * (float(N)**-0.5)
- attention = q.transpose(-1, -2) @ k
- attention = attention.softmax(dim=-1)
- x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- return x, size
-
-
-class ChannelBlock(nn.Module):
-
- def __init__(self,
- dim,
- groups,
- mlp_ratio=4.,
- qkv_bias=True,
- drop_path_rate=0.,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- conv_at_attn=True,
- conv_at_ffn=True):
- super().__init__()
-
- self.conv1 = PreNorm(None, DepthWiseConv2d(
- dim, 3, 1, 1)) if conv_at_attn else None
- self.channel_attn = PreNorm(
- norm_layer(dim),
- ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
- )
- self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
- 1)) if conv_at_ffn else None
- self.ffn = PreNorm(
- norm_layer(dim),
- Mlp(in_features=dim,
- hidden_features=int(dim * mlp_ratio),
- act_layer=act_layer),
- )
-
- def forward(self, x, size):
- if self.conv1:
- x, size = self.conv1(x, size)
- x, size = self.channel_attn(x, size)
-
- if self.conv2:
- x, size = self.conv2(x, size)
- x, size = self.ffn(x, size)
-
- return x, size
-
-
-def window_partition(x, window_size: int):
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size,
- C)
- windows = x.permute(0, 1, 3, 2, 4,
- 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
- B = batch_size
-
- x = windows.view(B, H // window_size, W // window_size, window_size,
- window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
-
- def __init__(self, dim, num_heads, window_size, qkv_bias=True):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = float(head_dim)**-0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.proj = nn.Linear(dim, dim)
-
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, size):
-
- H, W = size
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
-
- x = x.view(B, H, W, C)
-
- pad_l = pad_t = 0
- pad_r = (self.window_size - W % self.window_size) % self.window_size
- pad_b = (self.window_size - H % self.window_size) % self.window_size
- x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
- _, Hp, Wp, _ = x.shape
-
- x = window_partition(x, self.window_size)
- x = x.view(-1, self.window_size * self.window_size, C)
-
- # W-MSA/SW-MSA
- # attn_windows = self.attn(x_windows)
-
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
- C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
- attn = self.softmax(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
-
- # merge windows
- x = x.view(-1, self.window_size, self.window_size, C)
- x = window_reverse(x, B, self.window_size, Hp, Wp)
-
- if pad_r > 0 or pad_b > 0:
- x = x[:, :H, :W, :].contiguous()
-
- x = x.view(B, H * W, C)
-
- return x, size
-
-
-class SpatialBlock(nn.Module):
-
- def __init__(self,
- dim,
- num_heads,
- window_size,
- mlp_ratio=4.,
- qkv_bias=True,
- drop_path_rate=0.,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- conv_at_attn=True,
- conv_at_ffn=True):
- super().__init__()
-
- self.conv1 = PreNorm(None, DepthWiseConv2d(
- dim, 3, 1, 1)) if conv_at_attn else None
- self.window_attn = PreNorm(
- norm_layer(dim),
- WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
- )
- self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
- 1)) if conv_at_ffn else None
- self.ffn = PreNorm(
- norm_layer(dim),
- Mlp(in_features=dim,
- hidden_features=int(dim * mlp_ratio),
- act_layer=act_layer),
- )
-
- def forward(self, x, size):
- if self.conv1:
- x, size = self.conv1(x, size)
- x, size = self.window_attn(x, size)
-
- if self.conv2:
- x, size = self.conv2(x, size)
- x, size = self.ffn(x, size)
- return x, size
-
-
-class DaViT(nn.Module):
-
- def __init__(
- self,
- in_chans=3,
- num_classes=1000,
- depths=(1, 1, 3, 1),
- patch_size=(7, 2, 2, 2),
- patch_stride=(4, 2, 2, 2),
- patch_padding=(3, 0, 0, 0),
- patch_prenorm=(False, False, False, False),
- embed_dims=(64, 128, 192, 256),
- num_heads=(3, 6, 12, 24),
- num_groups=(3, 6, 12, 24),
- window_size=7,
- mlp_ratio=4.,
- qkv_bias=True,
- drop_path_rate=0.1,
- norm_layer=nn.LayerNorm,
- enable_checkpoint=False,
- conv_at_attn=True,
- conv_at_ffn=True,
- ):
- super().__init__()
-
- self.num_classes = num_classes
- self.embed_dims = embed_dims
- self.num_heads = num_heads
- self.num_groups = num_groups
- self.num_stages = len(self.embed_dims)
- self.enable_checkpoint = enable_checkpoint
- assert self.num_stages == len(self.num_heads) == len(self.num_groups)
-
- num_stages = len(embed_dims)
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate,
- sum(depths) * 2)
- ]
-
- depth_offset = 0
- convs = []
- blocks = []
- for i in range(num_stages):
- conv_embed = ConvEmbed(
- patch_size=patch_size[i],
- stride=patch_stride[i],
- padding=patch_padding[i],
- in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
- embed_dim=self.embed_dims[i],
- norm_layer=norm_layer,
- pre_norm=patch_prenorm[i])
- convs.append(conv_embed)
-
- block = MySequential(*[
- MySequential(
- OrderedDict([('spatial_block',
- SpatialBlock(
- embed_dims[i],
- num_heads[i],
- window_size,
- drop_path_rate=dpr[depth_offset + j * 2],
- qkv_bias=qkv_bias,
- mlp_ratio=mlp_ratio,
- conv_at_attn=conv_at_attn,
- conv_at_ffn=conv_at_ffn,
- )),
- ('channel_block',
- ChannelBlock(
- embed_dims[i],
- num_groups[i],
- drop_path_rate=dpr[depth_offset + j * 2 +
- 1],
- qkv_bias=qkv_bias,
- mlp_ratio=mlp_ratio,
- conv_at_attn=conv_at_attn,
- conv_at_ffn=conv_at_ffn,
- ))])) for j in range(depths[i])
- ])
- blocks.append(block)
- depth_offset += depths[i] * 2
-
- self.convs = nn.ModuleList(convs)
- self.blocks = nn.ModuleList(blocks)
-
- self.avgpool = nn.AdaptiveAvgPool1d(1)
-
- @property
- def dim_out(self):
- return self.embed_dims[-1]
-
- def forward_features_unpool(self, x):
- """
- forward until avg pooling
- Args:
- x (_type_): input image tensor
- """
- input_size = (x.size(2), x.size(3))
- for conv, block in zip(self.convs, self.blocks):
- x, input_size = conv(x, input_size)
- x, input_size = block(x, input_size)
- return x
-
- def forward_features(self, x):
- x = self.forward_features_unpool(x)
-
- # (batch_size, num_tokens, token_dim)
- x = self.avgpool(x.transpose(1, 2))
- # (batch_size, 1, num_tokens)
- x = torch.flatten(x, 1)
- x = self.norms(x)
-
- return x
-
- def forward(self, x):
- x = self.forward_features(x)
- x = self.head(x)
- return x
-
- @classmethod
- def from_config(cls, config):
- return cls(
- depths=config.depths,
- embed_dims=config.dim_embed,
- num_heads=config.num_heads,
- num_groups=config.num_groups,
- patch_size=config.patch_size,
- patch_stride=config.patch_stride,
- patch_padding=config.patch_padding,
- patch_prenorm=config.patch_prenorm,
- drop_path_rate=config.drop_path_rate,
- window_size=config.window_size,
- )
-
-
-# Language backbone and processor implementation
class Florence2LanguageModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -624,14 +73,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.vocab_size = config.vocab_size
self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
- self.encoder = BartEncoder(config,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.encoder")
- self.decoder = BartDecoder(config,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.decoder")
+ self.encoder = BartEncoder(
+ config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.encoder",
+ )
+ self.decoder = BartDecoder(
+ config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.decoder",
+ )
if self.config.tie_word_embeddings:
self.encoder.embed_tokens.weight = self.shared.weight
@@ -664,18 +117,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.config = config
- self.model = Florence2LanguageModel(vllm_config=vllm_config,
- prefix=f"{prefix}.model")
- embed_scale = math.sqrt(
- config.d_model) if config.scale_embedding else 1.0
+ self.model = Florence2LanguageModel(
+ vllm_config=vllm_config, prefix=f"{prefix}.model"
+ )
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.vocab_size = config.vocab_size
- self.lm_head = BartParallelLMHead(self.vocab_size,
- config.d_model,
- embed_scale=embed_scale)
+ self.lm_head = BartParallelLMHead(
+ self.vocab_size, config.d_model, embed_scale=embed_scale
+ )
- self.logits_processor = LogitsProcessor(self.vocab_size,
- config.vocab_size)
+ self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size)
if self.config.tie_word_embeddings:
self.lm_head.tie_weights(self.model.shared)
@@ -689,10 +141,12 @@ def forward(
# num_encoder_outputs: int | None = None,
**kwargs,
) -> torch.Tensor:
- return self.model(input_ids,
- positions,
- inputs_embeds=inputs_embeds,
- encoder_outputs=encoder_outputs)
+ return self.model(
+ input_ids,
+ positions,
+ inputs_embeds=inputs_embeds,
+ encoder_outputs=encoder_outputs,
+ )
def get_encoder_outputs(
self,
@@ -719,8 +173,7 @@ def compute_logits(
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
- def load_weights(self, weights: Iterable[tuple[str,
- torch.Tensor]]) -> set[str]:
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("encoder_attn.kv_proj", "encoder_attn.k_proj", "k"),
@@ -733,7 +186,7 @@ def load_weights(self, weights: Iterable[tuple[str,
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -747,8 +200,7 @@ def load_weights(self, weights: Iterable[tuple[str,
if self.config.tie_word_embeddings and "embed_tokens" in name:
continue
param = params_dict[name]
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@@ -756,22 +208,21 @@ def load_weights(self, weights: Iterable[tuple[str,
class Florence2ProcessingInfo(BaseProcessingInfo):
- def get_hf_config(self):
+ def get_hf_config(self) -> Florence2Config:
return self.ctx.get_hf_config()
- def get_hf_processor(self):
+ def get_hf_processor(self) -> Florence2Processor:
return self.ctx.get_hf_processor()
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": 1}
def get_num_image_tokens(self) -> int:
- processor_config = self.ctx.get_hf_image_processor_config()
- return processor_config["image_seq_length"]
+ processor = self.get_hf_processor()
+ return processor.num_image_tokens
-class Florence2DummyInputsBuilder(
- BaseDummyInputsBuilder[Florence2ProcessingInfo]):
+class Florence2DummyInputsBuilder(BaseDummyInputsBuilder[Florence2ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
@@ -784,18 +235,31 @@ def get_dummy_mm_data(
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
- target_width = target_height = self.info.get_hf_config().projection_dim
+ target_width = target_height = (
+ self.info.get_hf_config().vision_config.projection_dim
+ )
return {
- "image":
- self._get_dummy_images(width=target_width,
- height=target_height,
- num_images=num_images)
+ "image": self._get_dummy_images(
+ width=target_width, height=target_height, num_images=num_images
+ )
}
-class Florence2MultiModalProcessor(
- EncDecMultiModalProcessor[Florence2ProcessingInfo]):
+class Florence2MultiModalProcessor(EncDecMultiModalProcessor[Florence2ProcessingInfo]):
+
+ def __init__(self, info, dummy_inputs, *, cache=None) -> None:
+ super().__init__(info, dummy_inputs, cache=cache)
+ # Florence2Config does not expose decoder_start_token_id at the
+ # top level (it lives in text_config), so vLLM falls back to BOS
+ # (token 0) and incorrectly prepends it to the decoder prompt.
+ # Patch the top-level hf_config so vLLM's _prepare_decoder_input_ids
+ # sees the real value (EOS / token 2) and leaves our prompt intact.
+ hf_config = info.get_hf_config()
+ if getattr(hf_config, "decoder_start_token_id", None) is None:
+ hf_config.decoder_start_token_id = (
+ hf_config.text_config.decoder_start_token_id
+ )
def _hf_processor_applies_updates(
self,
@@ -804,7 +268,10 @@ def _hf_processor_applies_updates(
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
- return False
+ # The Florence2Processor already inserts image_token_id placeholders
+ # into the input_ids (577 tokens for a 768x768 image), so we tell
+ # vllm to find those existing placeholders rather than insert new ones.
+ return bool(mm_items.get_all_counts().get("image", 0))
def create_encoder_prompt(
self,
@@ -818,7 +285,16 @@ def create_decoder_prompt(
prompt: str | list[int],
mm_data: MultiModalDataDict,
) -> str | list[int]:
- return [self.info.get_hf_config().eos_token_id]
+ text_config = self.info.get_hf_config().text_config
+ # Decoder prompt mirrors what transformers does before open-ended
+ # generation: start with decoder_start_token_id (, token 2),
+ # then include forced_bos_token_id (, token 0) so that vLLM
+ # generates from the same position as transformers step 2.
+ decoder_prompt = [text_config.decoder_start_token_id]
+ forced_bos = getattr(text_config, "forced_bos_token_id", None)
+ if forced_bos is not None:
+ decoder_prompt.append(forced_bos)
+ return decoder_prompt
def _apply_hf_processor_tokens_only(
self,
@@ -841,14 +317,15 @@ def _call_hf_processor(
) -> BatchFeature:
if mm_data:
processed_outputs = super()._call_hf_processor(
- prompt, mm_data, mm_kwargs, tok_kwargs)
+ prompt, mm_data, mm_kwargs, tok_kwargs
+ )
else:
hf_processor = self.info.get_hf_processor()
tokenizer = hf_processor.tokenizer
prompt = hf_processor._construct_prompts([prompt])[0]
- processed_outputs = tokenizer(prompt,
- add_special_tokens=True,
- return_tensors="pt")
+ processed_outputs = tokenizer(
+ prompt, add_special_tokens=True, return_tensors="pt"
+ )
processed_outputs["encoder_input_ids"] = processed_outputs["input_ids"]
return processed_outputs
@@ -868,16 +345,38 @@ def _get_prompt_updates(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
- hf_config = self.info.get_hf_config()
- pad_token_id = hf_config.pad_token_id
- num_image_tokens = self.info.get_num_image_tokens()
- image_tokens = [pad_token_id] * num_image_tokens
+ # The placeholder must cover the FULL encoder input sequence (image
+ # tokens + text/task tokens) so that vLLM's _get_encoder_seq_lens
+ # computes the correct value for cross-attention KV cache allocation.
+ # Using only the image token count (577) would cause cross-attention
+ # to read only 577/590 K/V pairs, skipping the task-prompt tokens.
+ #
+ # With _hf_processor_applies_updates=True, vLLM detects the existing
+ # token sequence rather than inserting new tokens. By setting the
+ # insertion to the full encoder_input_ids sequence, the detected
+ # placeholder range covers all 590 encoder tokens.
+ insertion: list[int]
+ image_items = out_mm_kwargs.get("image", [])
+ if image_items:
+ item_data = image_items[0].get_data()
+ enc_ids = item_data.get("encoder_input_ids")
+ if enc_ids is not None:
+ insertion = enc_ids.tolist()
+ else:
+ # Cache hit: encoder_input_ids not available; fall back.
+ hf_config = self.info.get_hf_config()
+ insertion = [
+ hf_config.image_token_id
+ ] * self.info.get_num_image_tokens()
+ else:
+ hf_config = self.info.get_hf_config()
+ insertion = [hf_config.image_token_id] * self.info.get_num_image_tokens()
return [
PromptInsertion(
modality="image",
target=PromptIndexTargets.start(),
- insertion=image_tokens,
+ insertion=insertion,
)
]
@@ -885,8 +384,17 @@ def _get_prompt_updates(
@MULTIMODAL_REGISTRY.register_processor(
Florence2MultiModalProcessor,
info=Florence2ProcessingInfo,
- dummy_inputs=Florence2DummyInputsBuilder)
+ dummy_inputs=Florence2DummyInputsBuilder,
+)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "model.vision_tower.": "vision_tower.",
+ "model.multi_modal_projector.": "multi_modal_projector.",
+ "model.language_model.": "language_model.model.",
+ "lm_head.": "language_model.lm_head.",
+ }
+ )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
@@ -901,63 +409,39 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
processor_config = vllm_config.model_config.hf_image_processor_config
self.config = config
- self.vision_config = config.vision_config
self.processor_config = processor_config
- assert config.vision_config.model_type == 'davit', (
- 'only DaViT is supported for now')
- self.vision_tower = DaViT.from_config(config=config.vision_config)
- self._build_image_projection_layers(config)
+ assert config.vision_config.model_type == "florence_vision", (
+ f"only Florence Vision is supported for now. "
+ f"Received model type: {config.vision_config.model_type}"
+ )
+ self.vision_tower = Florence2VisionBackbone(config.vision_config)
+ self.multi_modal_projector = Florence2MultiModalProjector(config)
self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=f"{prefix}.language_model",
)
- self.pad_token_id = config.pad_token_id
-
- def _build_image_projection_layers(self, config: PretrainedConfig):
- image_dim_out = config.vision_config.dim_embed[-1]
- dim_projection = config.vision_config.projection_dim
- self.image_projection = nn.Parameter(
- torch.empty(image_dim_out, dim_projection))
- self.image_proj_norm = nn.LayerNorm(dim_projection)
- image_pos_embed_config = config.vision_config.image_pos_embed
- if image_pos_embed_config['type'] == 'learned_abs_2d':
- self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
- embedding_dim=image_dim_out,
- num_pos=image_pos_embed_config['max_pos_embeddings'])
- else:
- raise NotImplementedError("Florence2 only supports learned_abs_2d "
- "as image position embedding.")
-
- self.image_feature_source = config.vision_config.image_feature_source
-
- # temporal embedding
- visual_temporal_embedding_config = (
- self.vision_config.visual_temporal_embedding)
- if visual_temporal_embedding_config['type'] == 'COSINE':
- self.visual_temporal_embed = PositionalEmbeddingCosine1D(
- embed_dim=image_dim_out,
- max_seq_len=visual_temporal_embedding_config[
- 'max_temporal_embeddings'])
- else:
- raise NotImplementedError(
- 'Florence2 only supports COSINE as temporal embedding.')
+ self.pad_token_id = config.text_config.pad_token_id
def _validate_pixel_values(
self, data: torch.Tensor | list[torch.Tensor]
) -> torch.Tensor | list[torch.Tensor]:
+ # The image processor config may use "size" or "crop_size"; fall back
+ # to reading the actual tensor shape if neither key is available.
+ cfg = self.processor_config
+ size = cfg.get("size") or cfg.get("crop_size")
+ if size is None:
+ return data
- size = self.processor_config["size"]
h, w = size["height"], size["width"]
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
-
if actual_dims != expected_dims:
- expected_expr = tuple(*map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per batch "
- f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+ f"is {expected_dims}. You supplied {actual_dims}."
+ )
for d in data:
_validate_shape(d)
@@ -965,112 +449,43 @@ def _validate_shape(d: torch.Tensor):
return data
def _parse_and_validate_image_input(self, **kwargs: object):
- pixel_values: list[list[torch.Tensor]] | list[torch.Tensor] | torch.Tensor | None = kwargs.pop(
- "pixel_values", None)
- image_embeds: list[list[torch.Tensor]] | list[torch.Tensor] | torch.Tensor | None = kwargs.pop(
- "image_embeds", None)
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
-
if pixel_values is not None and image_embeds is not None:
- raise ValueError(
- "Both pixel values and image embeds are provided.")
-
+ raise ValueError("Both pixel values and image embeds are provided.")
if pixel_values is not None:
return Florence2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
-
- if image_embeds is not None:
- raise NotImplementedError
-
- raise AssertionError("This line should be unreachable.")
+ raise NotImplementedError("image_embeds not supported.")
def _parse_and_validate_encoder_input(self, **kwargs: object) -> list[torch.Tensor]:
encoder_input_ids = kwargs.get("encoder_input_ids", kwargs.get("input_ids"))
-
if encoder_input_ids is None:
return []
-
if not isinstance(encoder_input_ids, (torch.Tensor, list)):
raise ValueError(
- "Incorrect type of encoder input_ids. "
- f"Got type: {type(encoder_input_ids)}"
+ f"Incorrect type of encoder input_ids. Got type: {type(encoder_input_ids)}"
)
-
- # Return as a list of tensors (one per item in the batch)
if isinstance(encoder_input_ids, list):
- # Already a list - ensure each item is valid
- result = []
- for item in encoder_input_ids:
- if isinstance(item, torch.Tensor):
- if item.dim() == 0:
- item = item.unsqueeze(0)
- result.append(item)
- else:
- result.append(item)
- return result
- else:
- # [1xD]xN times
- return encoder_input_ids.unsqueeze(1).unbind(dim=0)
+ return [
+ item.unsqueeze(0) if item.dim() == 0 else item
+ for item in encoder_input_ids
+ ]
+ return encoder_input_ids.unsqueeze(1).unbind(dim=0)
def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
- dtype = next(self.vision_tower.parameters()).dtype
- pixel_values = pixel_values.to(dtype)
-
- batch_size, T = pixel_values.size(0), 1
- x = self.vision_tower.forward_features_unpool(pixel_values)
- if self.image_pos_embed is not None:
- x = x.view(batch_size * T, -1, x.shape[-1])
- num_tokens = x.shape[-2]
- h, w = int(num_tokens**0.5), int(num_tokens**0.5)
- assert h * w == num_tokens, (
- 'only support square feature maps for now')
- x = x.view(batch_size * T, h, w, x.shape[-1])
- pos_embed = self.image_pos_embed(x)
- x = x + pos_embed
- x = x.view(batch_size, T * h * w, x.shape[-1])
-
- if self.visual_temporal_embed is not None:
- visual_temporal_embed = self.visual_temporal_embed(
- x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
- x = x.view(batch_size, T, -1,
- x.shape[-1]) + visual_temporal_embed.view(
- 1, T, 1, x.shape[-1])
-
- x_feat_dict = {}
-
- spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
- x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
-
- temporal_avg_pool_x = x.view(batch_size, T, -1,
- x.shape[-1]).mean(dim=1)
- x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
-
- x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
- x_feat_dict['last_frame'] = x
-
- new_x = []
- for _image_feature_source in self.image_feature_source:
- if _image_feature_source not in x_feat_dict:
- raise ValueError('invalid image feature source: {}'.format(
- _image_feature_source))
- new_x.append(x_feat_dict[_image_feature_source])
-
- x = torch.cat(new_x, dim=1)
-
- x = x @ self.image_projection
- x = self.image_proj_norm(x)
-
- return x
+ pixel_values = pixel_values.to(next(self.vision_tower.parameters()).dtype)
+ return self.multi_modal_projector(self.vision_tower(pixel_values))
def _process_image_input(
- self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
- assert image_input["type"] == "pixel_values"
- pixel_values = image_input["data"]
- return self._encode_image(pixel_values)
+ self, image_input: Florence2ImagePixelInputs
+ ) -> torch.Tensor:
+ return self._encode_image(image_input["data"])
def get_language_model(self) -> torch.nn.Module:
return self.language_model
@@ -1085,14 +500,10 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
if not encoder_input_ids_list:
raise ValueError(
- "encoder_input_ids_list is empty - this should not happen. "
- "Check that multimodal data is being passed correctly."
+ "encoder_input_ids_list is empty - check multimodal data is being passed correctly."
)
- # Process each encoder input separately and return a list of outputs
- # NOTE (NickLucche): Basic encoder batching optimization: BART input sequences
- # can have different lengths. Due to computational load of encoder being very
- # low here, we batch all sequences to run a single forward by max_seq padding.
+ # Batch encoder inputs (pad to max length if needed) and run a single forward pass.
lengths = [t.numel() for t in encoder_input_ids_list]
max_len = max(lengths) if lengths else 0
assert max_len > 0, "Empty encoder_input_ids encountered."
@@ -1100,45 +511,47 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
if len(encoder_input_ids_list) == 1:
batch_encoder_input_ids = encoder_input_ids_list[0]
elif same_len:
- # [1xD]xN =>NxD
batch_encoder_input_ids = torch.cat(encoder_input_ids_list, dim=0)
else:
batch_encoder_input_ids = torch.full(
(len(encoder_input_ids_list), max_len),
- fill_value=self._pad_id,
+ fill_value=self.pad_token_id,
dtype=encoder_input_ids_list[0].dtype,
device=encoder_input_ids_list[0].device,
)
for i, t in enumerate(encoder_input_ids_list):
batch_encoder_input_ids[i, : t.numel()] = t.squeeze()
- # Create (B, T) positions: 0..T-1 for each item.
- # batch_encoder_positions = torch.arange(
- # max_len,
- # dtype=torch.long,
- # device=batch_encoder_input_ids.device,
- # ).unsqueeze(0).expand(batch_encoder_input_ids.size(0), -1)
-
- inputs_embeds = self.language_model.model.encoder.embed_tokens(batch_encoder_input_ids)
- inputs_embeds = torch.cat([vision_embeddings, inputs_embeds], dim=-2)
- batch_encoder_positions = torch.arange(
- inputs_embeds.size(1),
- dtype=torch.long,
- device=inputs_embeds.device,
- ).unsqueeze(0).expand(inputs_embeds.size(0), -1)
-
- # Run encoder once on the batch.
+ inputs_embeds = self.language_model.model.encoder.embed_tokens(
+ batch_encoder_input_ids
+ )
+
+ # Replace the leading image_token_id placeholders with vision features.
+ if (
+ isinstance(vision_embeddings, torch.Tensor)
+ and vision_embeddings.numel() > 0
+ ):
+ num_vision = vision_embeddings.size(1)
+ inputs_embeds = inputs_embeds.clone()
+ inputs_embeds[:, :num_vision, :] = vision_embeddings
+ batch_encoder_positions = (
+ torch.arange(
+ inputs_embeds.size(1),
+ dtype=torch.long,
+ device=inputs_embeds.device,
+ )
+ .unsqueeze(0)
+ .expand(inputs_embeds.size(0), -1)
+ )
+
+ # Run encoder once on the batch, then split back per item.
batch_encoder_output = self.language_model.model.encoder(
input_ids=batch_encoder_input_ids,
positions=batch_encoder_positions,
inputs_embeds=inputs_embeds,
)
- # Split back into list[(T, H)] to match expected downstream format.
- # If we had to pad, slice back to the original lengths per item.
encoder_outputs: list[torch.Tensor] = batch_encoder_output.unbind(dim=0)
if not same_len:
- encoder_outputs = [
- out[:l] for out, l in zip(encoder_outputs, lengths)
- ]
+ encoder_outputs = [out[:l] for out, l in zip(encoder_outputs, lengths)]
return encoder_outputs
def forward(
@@ -1148,30 +561,18 @@ def forward(
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
encoder_outputs: torch.Tensor | None = None,
- # num_encoder_outputs: int | None = None,
**kwargs,
) -> torch.Tensor:
- r"""
- Args:
- input_ids
- torch.Tensor of *decoder* input token ids.
- positions
- torch.Tensor of *decoder* position indices.
- encoder_input_ids
- torch.Tensor of *encoder* input token ids.
- encoder_positions
- torch.Tensor of *encoder* position indices
- Returns:
- Output torch.Tensor
- """
if encoder_outputs is not None:
# Assume same shape for all encoder outputs
encoder_outputs = torch.cat(encoder_outputs, dim=0)
- hidden_states = self.language_model(input_ids,
- positions,
- encoder_outputs=encoder_outputs,
- inputs_embeds=inputs_embeds)
+ hidden_states = self.language_model(
+ input_ids,
+ positions,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ )
return hidden_states
def compute_logits(
@@ -1180,7 +581,11 @@ def compute_logits(
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
- def load_weights(self, weights: Iterable[tuple[str,
- torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- return loader.load_weights(weights)
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ # pos_idx_to_embed is a register_buffer in the transformers implementation
+ # (deterministically computed from config), so it has no matching parameter.
+ loader = AutoWeightsLoader(
+ self,
+ ignore_unexpected_suffixes=["visual_temporal_embed.pos_idx_to_embed"],
+ )
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)