diff --git a/README.md b/README.md index 60efa88..6e0c11a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,13 @@ This plugin requires [uv](https://docs.astral.sh/uv/) for package management. If ```bash curl -LsSf https://astral.sh/uv/install.sh | sh ``` +### From Git + +Install from git: + +```bash +pip install git+https://github.com/vllm-project/bart-plugin +``` ### From Source @@ -126,10 +133,8 @@ This plugin should work with any BART-based model from HuggingFace, including: ### Florence-2 Models -- `microsoft/Florence-2-base` -- `microsoft/Florence-2-large` - -Note: Florence-2 requires `trust_remote_code=True` and uses a separate tokenizer (`Isotr0py/Florence-2-tokenizer`). +- `florence-community/Florence-2-base` +- `florence-community/Florence-2-large` ## Evaluation @@ -186,11 +191,14 @@ Notes: ``` bart-plugin/ ├── vllm_bart_plugin/ -│ ├── __init__.py # Plugin registration -│ └── bart.py # BART model implementation -├── setup.py # Package configuration and entry points -├── README.md # This file -└── LICENSE # License file +│ ├── __init__.py # Plugin registration +│ └── bart.py # BART model implementation +│ └── florence2.py # Florence-2 model implementation +├── setup.py # Package configuration and entry points +├── README.md # This file +└── LICENSE # License file +└── example_bart_usage.py # Example usage script for BART +└── example_florence2_usage.py # Example usage script for Florence-2 ``` ### Running Tests diff --git a/example_florence2_usage.py b/example_florence2_usage.py index 6e9593b..32c9c96 100644 --- a/example_florence2_usage.py +++ b/example_florence2_usage.py @@ -5,28 +5,23 @@ This script demonstrates how to use Florence-2 models with vLLM after installing the BART plugin. """ -import vllm_bart_plugin + from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset def main(): """Run Florence-2 model examples.""" - model_name = "microsoft/Florence-2-large" - tokenizer_name = "Isotr0py/Florence-2-tokenizer" + model_name = "florence-community/Florence-2-large-ft" llm = LLM( model=model_name, - tokenizer=tokenizer_name, mm_processor_cache_gb=0, - trust_remote_code=True, enforce_eager=True, ) params = SamplingParams( temperature=0.0, max_tokens=20, - # repetition_penalty is needed to prevent repetition - repetition_penalty=1.5, # skip_special_tokens=False is needed to present # grounding tokens like skip_special_tokens=False, @@ -60,4 +55,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pyproject.toml b/pyproject.toml index c2627f2..e1789ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vllm-bart-plugin" -version = "0.2.0" +version = "0.3.0" description = "BART model plugin for vLLM" readme = "README.md" requires-python = ">=3.10" @@ -26,9 +26,9 @@ classifiers = [ ] dependencies = [ - "vllm>=0.14.0", + "vllm>=0.16.0", "torch>=2.9.0", - "transformers>=4.56.0,<5", + "transformers>=4.56.0,<6", ] [project.optional-dependencies] @@ -62,8 +62,20 @@ include = '\.pyi?$' profile = "black" line_length = 88 +[tool.pytest.ini_options] +markers = [ + "slow: marks tests requiring a GPU and full model download (deselect with '-m \"not slow\"')", +] + [tool.mypy] python_version = "3.10" warn_return_any = true warn_unused_configs = true ignore_missing_imports = true + +[dependency-groups] +dev = [ + "black>=26.1.0", + "isort>=8.0.1", + "pytest>=9.0.2", +] diff --git a/tests/test_florence2.py b/tests/test_florence2.py new file mode 100644 index 0000000..38e5731 --- /dev/null +++ b/tests/test_florence2.py @@ -0,0 +1,267 @@ +"""Tests for the Florence-2 multimodal model plugin.""" + +import os + +import pytest +import torch +from transformers import Florence2Config + +MODEL_NAME = "florence-community/Florence-2-base-ft" + + +def _small_vision_config(): + """Tiny 1-stage Florence2 config for fast CPU tests.""" + cfg = Florence2Config() + vc = cfg.vision_config + vc.embed_dim = [64] + vc.depths = [1] + vc.num_heads = [4] + vc.num_groups = [4] + vc.patch_size = [7] + vc.patch_stride = [4] + vc.patch_padding = [3] + vc.patch_prenorm = [False] + vc.drop_path_rate = 0.0 + return cfg, vc + + +# --------------------------------------------------------------------------- +# Unit tests — vision architecture (CPU, no weights) +# --------------------------------------------------------------------------- + + +class TestFlorenceVisionBackbone: + def test_output_shape(self): + from vllm_bart_plugin.florence2 import Florence2VisionBackbone + + _, vc = _small_vision_config() + out = Florence2VisionBackbone(vc)(torch.randn(2, 3, 64, 64)) + assert out.shape == (2, vc.embed_dim[-1], 16, 16) + + +class TestFlorenceMultiModalProjector: + def test_output_shape(self): + from vllm_bart_plugin.florence2 import Florence2MultiModalProjector + + cfg, vc = _small_vision_config() + vc.projection_dim = 128 + m = Florence2MultiModalProjector(cfg) + out = m(torch.randn(2, vc.embed_dim[-1], 12, 12)) + # (B, 1 spatial-avg token + H*W tokens, proj_dim) + assert out.shape == (2, 1 + 12 * 12, vc.projection_dim) + + +# --------------------------------------------------------------------------- +# Integration tests — full model inference (GPU required) +# --------------------------------------------------------------------------- + + +def _run_task(llm, processor, image, task_prompt, text_input=None, max_tokens=100): + """Helper: run one Florence-2 task and return the post-processed result.""" + from vllm import SamplingParams + + prompt = task_prompt if text_input is None else task_prompt + text_input + params = SamplingParams( + temperature=0.0, max_tokens=max_tokens, skip_special_tokens=False + ) + outputs = llm.generate( + [{"prompt": prompt, "multi_modal_data": {"image": image}}], + sampling_params=params, + ) + raw = outputs[0].outputs[0].text + return processor.post_process_generation( + raw, task=task_prompt, image_size=image.size + ) + + +@pytest.fixture(scope="module") +def florence2_llm(): + from vllm import LLM + + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + return LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=0.5, + mm_processor_cache_gb=0, + ) + + +@pytest.fixture(scope="module") +def florence2_processor(): + from transformers import AutoProcessor + + return AutoProcessor.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def stop_sign_image(): + from vllm.assets.image import ImageAsset + + return ImageAsset("stop_sign").pil_image.convert("RGB") + + +@pytest.mark.slow +class TestFlorenceInference: + # ------------------------------------------------------------------ + # Caption tasks — check for semantically meaningful keywords + # ------------------------------------------------------------------ + + def test_caption(self, florence2_llm, florence2_processor, stop_sign_image): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + max_tokens=30, + ) + text = result[""].lower() + assert ( + "car" in text or "stop" in text + ), f" output missing expected content: {text!r}" + + def test_detailed_caption( + self, florence2_llm, florence2_processor, stop_sign_image + ): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + max_tokens=80, + ) + text = result[""].lower() + # Must mention the car and give some background detail — guards against the + # KV-cache encoder_seq_lens regression that previously produced garbled output. + assert "car" in text, f" missing 'car': {text!r}" + assert len(text.split()) >= 10, f" too short: {text!r}" + + def test_more_detailed_caption( + self, florence2_llm, florence2_processor, stop_sign_image + ): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + max_tokens=100, + ) + text = result[""].lower() + assert ( + "stop sign" in text or "sign" in text + ), f" missing 'stop sign': {text!r}" + assert len(text.split()) >= 10, f" too short: {text!r}" + + # ------------------------------------------------------------------ + # Structured-output tasks — check schema and key labels + # ------------------------------------------------------------------ + + def test_object_detection( + self, florence2_llm, florence2_processor, stop_sign_image + ): + result = _run_task( + florence2_llm, florence2_processor, stop_sign_image, "", max_tokens=300 + ) + od = result[""] + assert "bboxes" in od and "labels" in od + assert len(od["bboxes"]) == len(od["labels"]) > 0 + # Each bbox must be a 4-element list with non-negative coords + for bbox in od["bboxes"]: + assert len(bbox) == 4 and all(c >= 0 for c in bbox) + labels = od["labels"] + assert ( + "stop sign" in labels + ), f"Expected 'stop sign' in OD labels, got: {labels}" + assert ( + "car" in labels or "building" in labels + ), f"Expected common objects in OD labels, got: {labels}" + + def test_dense_region_caption( + self, florence2_llm, florence2_processor, stop_sign_image + ): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + max_tokens=250, + ) + drc = result[""] + assert "bboxes" in drc and "labels" in drc + assert len(drc["bboxes"]) == len(drc["labels"]) > 0 + assert ( + "stop sign" in drc["labels"] + ), f"Expected 'stop sign' in dense captions, got: {drc['labels']}" + + def test_region_proposal(self, florence2_llm, florence2_processor, stop_sign_image): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + max_tokens=100, + ) + rp = result[""] + assert "bboxes" in rp and "labels" in rp + assert len(rp["bboxes"]) > 0 + # Region proposal labels are always empty strings + assert all(label == "" for label in rp["labels"]) + + def test_ocr_with_region(self, florence2_llm, florence2_processor, stop_sign_image): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + max_tokens=250, + ) + ocr = result[""] + assert "quad_boxes" in ocr and "labels" in ocr + assert len(ocr["quad_boxes"]) == len(ocr["labels"]) > 0 + # Each quad box must be 8 coords + for quad in ocr["quad_boxes"]: + assert len(quad) == 8 + # "STOP" is the most prominent text in the image + joined = " ".join(ocr["labels"]) + assert ( + "STOP" in joined + ), f"Expected 'STOP' in OCR_WITH_REGION labels, got: {joined!r}" + + def test_caption_to_phrase_grounding( + self, florence2_llm, florence2_processor, stop_sign_image + ): + result = _run_task( + florence2_llm, + florence2_processor, + stop_sign_image, + "", + text_input="A stop sign on a street corner.", + max_tokens=80, + ) + cpg = result[""] + assert "bboxes" in cpg and "labels" in cpg + assert len(cpg["bboxes"]) > 0 + assert any( + "stop sign" in lbl.lower() for lbl in cpg["labels"] + ), f"Expected 'stop sign' grounded, got labels: {cpg['labels']}" + + # ------------------------------------------------------------------ + # Batch tests + # ------------------------------------------------------------------ + + def test_batch_inference(self, florence2_llm, florence2_processor, stop_sign_image): + """Multiple prompts in one batch must all produce non-empty output.""" + from vllm import SamplingParams + + params = SamplingParams( + temperature=0.0, max_tokens=30, skip_special_tokens=False + ) + prompts = [ + {"prompt": "", "multi_modal_data": {"image": stop_sign_image}}, + { + "prompt": "", + "multi_modal_data": {"image": stop_sign_image}, + }, + ] + outputs = florence2_llm.generate(prompts, sampling_params=params) + assert all(len(o.outputs[0].text) > 0 for o in outputs) diff --git a/tests/test_model_initialization.py b/tests/test_model_initialization.py index 2a44aec..7f01bec 100644 --- a/tests/test_model_initialization.py +++ b/tests/test_model_initialization.py @@ -1,7 +1,8 @@ """Tests for BART model initialization.""" import pytest -import torch +import vllm_bart_plugin +vllm_bart_plugin.register_bart_model() from vllm import LLM @@ -42,6 +43,7 @@ def test_model_with_custom_config(self, small_model_name): except Exception as e: pytest.fail(f"Failed to load model with config: {e}") + @pytest.mark.slow def test_model_class_initialization(self): """Test that model class can be instantiated.""" from vllm_bart_plugin.bart import BartForConditionalGeneration @@ -53,7 +55,6 @@ def test_model_class_initialization(self): model_config = ModelConfig( model="facebook/bart-large-cnn", - task="generate", tokenizer="facebook/bart-large-cnn", tokenizer_mode="auto", trust_remote_code=False, @@ -65,7 +66,6 @@ def test_model_class_initialization(self): cache_config = CacheConfig( block_size=16, gpu_memory_utilization=0.3, - swap_space_bytes=0, cache_dtype="auto", ) @@ -77,7 +77,9 @@ def test_model_class_initialization(self): # Try to instantiate the model try: - model = BartForConditionalGeneration(vllm_config=vllm_config) + from vllm.config import set_current_vllm_config + with set_current_vllm_config(vllm_config): + model = BartForConditionalGeneration(vllm_config=vllm_config) assert model is not None assert hasattr(model, 'model') assert hasattr(model, 'lm_head') @@ -92,13 +94,14 @@ def test_model_has_required_methods(self): 'forward', 'compute_logits', 'load_weights', - 'get_multimodal_embeddings', + 'embed_multimodal', ] for method in required_methods: assert hasattr(BartForConditionalGeneration, method), \ f"Model missing required method: {method}" + @pytest.mark.slow def test_encoder_decoder_structure(self): """Test that BART has proper encoder-decoder structure.""" from vllm_bart_plugin.bart import BartModel, BartEncoder, BartDecoder @@ -109,7 +112,6 @@ def test_encoder_decoder_structure(self): model_config = ModelConfig( model="facebook/bart-large-cnn", - task="generate", tokenizer="facebook/bart-large-cnn", tokenizer_mode="auto", trust_remote_code=False, @@ -121,7 +123,6 @@ def test_encoder_decoder_structure(self): cache_config = CacheConfig( block_size=16, gpu_memory_utilization=0.3, - swap_space_bytes=0, cache_dtype="auto", ) @@ -131,7 +132,9 @@ def test_encoder_decoder_structure(self): load_config=LoadConfig(), ) - model = BartModel(vllm_config=vllm_config) + from vllm.config import set_current_vllm_config + with set_current_vllm_config(vllm_config): + model = BartModel(vllm_config=vllm_config) assert hasattr(model, 'encoder') assert hasattr(model, 'decoder') diff --git a/vllm_bart_plugin/__init__.py b/vllm_bart_plugin/__init__.py index 0648366..df079d7 100644 --- a/vllm_bart_plugin/__init__.py +++ b/vllm_bart_plugin/__init__.py @@ -34,10 +34,10 @@ def register_bart_model() -> None: "vllm_bart_plugin.florence2:Florence2ForConditionalGeneration", ) - logger.info("Successfully registered BART model with vLLM") + logger.info("Successfully registered BART and Florence2 models with vLLM") except Exception as e: - logger.error(f"Failed to register BART model: {e}") + logger.error(f"Failed to register BART and Florence2 models: {e}") raise diff --git a/vllm_bart_plugin/bart.py b/vllm_bart_plugin/bart.py index f3d6cd5..a21e3fc 100644 --- a/vllm_bart_plugin/bart.py +++ b/vllm_bart_plugin/bart.py @@ -29,7 +29,8 @@ from torch import nn from transformers import BartConfig from transformers.utils import logging -from vllm.attention.layer import Attention, AttentionType +from vllm.model_executor.layers.attention import Attention +from vllm.v1.attention.backend import AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.config.lora import LoRAConfig from vllm.config.multimodal import BaseDummyOptions @@ -78,7 +79,7 @@ EncDecMultiModalProcessor, PromptUpdate, ) -from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.collection_utils import is_list_of @@ -927,6 +928,9 @@ def get_mm_max_tokens_per_item( config = self.get_hf_config() return {"text": config.max_position_embeddings} + def get_data_parser(self) -> "MultiModalDataParser": + return TextDataParser() + class BartDummyInputsBuilder(BaseDummyInputsBuilder[BartProcessingInfo]): """Builds dummy inputs for profiling BART models.""" @@ -1107,7 +1111,7 @@ def _get_prompt_updates( ) ] - def _get_data_parser(self) -> MultiModalDataParser: + def build_data_parser(self) -> MultiModalDataParser: return TextDataParser() @@ -1152,8 +1156,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.d_model, embed_scale=embed_scale ) # Bias added to logits after lm_head, matching HuggingFace approach - self.register_buffer("final_logits_bias", - torch.zeros((1, config.vocab_size))) + self.register_buffer("final_logits_bias", torch.zeros((1, config.vocab_size))) self.logits_processor = LogitsProcessor( self.unpadded_vocab_size, config.vocab_size ) diff --git a/vllm_bart_plugin/florence2.py b/vllm_bart_plugin/florence2.py index 3624b1f..b215dc1 100644 --- a/vllm_bart_plugin/florence2.py +++ b/vllm_bart_plugin/florence2.py @@ -1,66 +1,56 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Literal, TypedDict, OrderedDict import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any - -import torch.nn.functional as F -from einops import rearrange +from typing import Literal, TypedDict import torch from torch import nn -from transformers import BartConfig, BatchFeature, BartTokenizer, PretrainedConfig -from transformers.utils import logging - -from vllm.attention.layer import Attention, AttentionType -from vllm.model_executor.layers.attention.cross_attention import CrossAttention -from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention -from vllm.config import CacheConfig, VllmConfig -from vllm.config.lora import LoRAConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, +from transformers import ( + BartTokenizer, + BatchFeature, + Florence2Config, + Florence2Processor, ) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, +from transformers.models.florence2.modeling_florence2 import ( + Florence2MultiModalProjector, + Florence2VisionBackbone, ) +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal import MULTIMODAL_REGISTRY, ModalityData +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, +) +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) -from vllm.multimodal.parse import ( - ModalityDataItems, - ModalityDataParser, - MultiModalDataItems, - MultiModalDataParser, - ProcessorBatchItems, -) +from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import ( BaseProcessingInfo, EncDecMultiModalProcessor, - PromptUpdate, - PromptInsertion, PromptIndexTargets, + PromptInsertion, + PromptUpdate, ) -from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils.collection_utils import is_list_of - -from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant -from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, maybe_prefix, flatten_bn -from vllm_bart_plugin.bart import BartDecoder, BartEncoder, BartParallelLMHead, BartScaledWordEmbedding +from vllm_bart_plugin.bart import ( + BartDecoder, + BartEncoder, + BartParallelLMHead, + BartScaledWordEmbedding, +) class Florence2ImagePixelInputs(TypedDict): @@ -69,547 +59,6 @@ class Florence2ImagePixelInputs(TypedDict): """Shape: (batch_size, num_channel, height, width)""" -# ViT implementation are all copied from -# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py -class LearnedAbsolutePositionEmbedding2D(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, embedding_dim=256, num_pos=50): - super().__init__() - self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) - self.column_embeddings = nn.Embedding( - num_pos, embedding_dim - (embedding_dim // 2)) - - def forward(self, pixel_values): - """ - pixel_values: (batch_size, height, width, num_channels) - returns: (batch_size, height, width, embedding_dim * 2) - """ - if len(pixel_values.shape) != 4: - raise ValueError('pixel_values must be a 4D tensor') - height, width = pixel_values.shape[1:3] - width_values = torch.arange(width, device=pixel_values.device) - height_values = torch.arange(height, device=pixel_values.device) - x_emb = self.column_embeddings(width_values) - y_emb = self.row_embeddings(height_values) - # (height, width, embedding_dim * 2) - pos = torch.cat([ - x_emb.unsqueeze(0).repeat(height, 1, 1), - y_emb.unsqueeze(1).repeat(1, width, 1) - ], - dim=-1) - # (embedding_dim * 2, height, width) - pos = pos.permute(2, 0, 1) - pos = pos.unsqueeze(0) - # (batch_size, embedding_dim * 2, height, width) - pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) - # (batch_size, height, width, embedding_dim * 2) - pos = pos.permute(0, 2, 3, 1) - return pos - - -class PositionalEmbeddingCosine1D(nn.Module): - """ - This class implements a very simple positional encoding. It follows closely - the encoder from the link below: - https://pytorch.org/tutorials/beginner/translation_transformer.html - Args: - embed_dim: The dimension of the embeddings. - dropout_prob: The dropout probability. - max_seq_len: The maximum length to precompute the positional encodings. - """ - - def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: - super().__init__() - self.embed_dim = embed_dim - self.max_seq_len = max_seq_len - # Generate the sinusoidal arrays. - factor = math.log(10000) - denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / - self.embed_dim) - # Matrix where rows correspond to a positional embedding as a function - # of the position index (i.e., the row index). - frequencies = \ - torch.arange(0, self.max_seq_len) \ - .reshape(self.max_seq_len, 1) * denominator - pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) - # Populate uneven entries. - pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) - pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) - # Save the positional embeddings in a constant buffer. - # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) - self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, - requires_grad=False) - - def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: - """ - Args: - seq_embeds: The sequence embeddings in order. Allowed size: - 1. [T, D], where T is the length of the sequence, and D is the - frame embedding dimension. - 2. [B, T, D], where B is the batch size and T and D are the - same as above. - Returns a tensor of with the same dimensions as the input: i.e., - [1, T, D] or [T, D]. - """ - shape_len = len(seq_embeds.shape) - assert 2 <= shape_len <= 3 - len_seq = seq_embeds.size(-2) - assert len_seq <= self.max_seq_len - pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] - # Adapt pre-computed positional embeddings to the input. - if shape_len == 3: - pos_embeds = pos_embeds.view( - (1, pos_embeds.size(0), pos_embeds.size(1))) - return pos_embeds - - -class MySequential(nn.Sequential): - - def forward(self, *inputs): - for module in self._modules.values(): - if isinstance(inputs, tuple): - inputs = module(*inputs) - else: - inputs = module(inputs) - return inputs - - -class PreNorm(nn.Module): - - def __init__(self, norm, fn): - super().__init__() - self.norm = norm - self.fn = fn - - def forward(self, x, *args, **kwargs): - shortcut = x - if self.norm is not None: - x, size = self.fn(self.norm(x), *args, **kwargs) - else: - x, size = self.fn(x, *args, **kwargs) - - x = shortcut + x - - return x, size - - -class Mlp(nn.Module): - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.net = nn.Sequential( - OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), - ("act", act_layer()), - ("fc2", nn.Linear(hidden_features, out_features))])) - - def forward(self, x, size): - return self.net(x), size - - -class DepthWiseConv2d(nn.Module): - - def __init__( - self, - dim_in, - kernel_size, - padding, - stride, - bias=True, - ): - super().__init__() - self.dw = nn.Conv2d(dim_in, - dim_in, - kernel_size=kernel_size, - padding=padding, - groups=dim_in, - stride=stride, - bias=bias) - - def forward(self, x, size): - B, N, C = x.shape - H, W = size - assert N == H * W - - x = self.dw(x.transpose(1, 2).view(B, C, H, W)) - size = (x.size(-2), x.size(-1)) - x = x.flatten(2).transpose(1, 2) - return x, size - - -class ConvEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, - patch_size=7, - in_chans=3, - embed_dim=64, - stride=4, - padding=2, - norm_layer=None, - pre_norm=True): - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d(in_chans, - embed_dim, - kernel_size=patch_size, - stride=stride, - padding=padding) - - dim_norm = in_chans if pre_norm else embed_dim - self.norm = norm_layer(dim_norm) if norm_layer else None - - self.pre_norm = pre_norm - - def forward(self, x, size): - H, W = size - if len(x.size()) == 3: - if self.norm and self.pre_norm: - x = self.norm(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) - - x = self.proj(x) - - _, _, H, W = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - if self.norm and not self.pre_norm: - x = self.norm(x) - - return x, (H, W) - - -class ChannelAttention(nn.Module): - - def __init__(self, dim, groups=8, qkv_bias=True): - super().__init__() - - self.groups = groups - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x, size): - B, N, C = x.shape - - qkv = self.qkv(x).reshape(B, N, 3, self.groups, - C // self.groups).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * (float(N)**-0.5) - attention = q.transpose(-1, -2) @ k - attention = attention.softmax(dim=-1) - x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - return x, size - - -class ChannelBlock(nn.Module): - - def __init__(self, - dim, - groups, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - conv_at_attn=True, - conv_at_ffn=True): - super().__init__() - - self.conv1 = PreNorm(None, DepthWiseConv2d( - dim, 3, 1, 1)) if conv_at_attn else None - self.channel_attn = PreNorm( - norm_layer(dim), - ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), - ) - self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, - 1)) if conv_at_ffn else None - self.ffn = PreNorm( - norm_layer(dim), - Mlp(in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer), - ) - - def forward(self, x, size): - if self.conv1: - x, size = self.conv1(x, size) - x, size = self.channel_attn(x, size) - - if self.conv2: - x, size = self.conv2(x, size) - x, size = self.ffn(x, size) - - return x, size - - -def window_partition(x, window_size: int): - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) - windows = x.permute(0, 1, 3, 2, 4, - 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): - B = batch_size - - x = windows.view(B, H // window_size, W // window_size, window_size, - window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - - def __init__(self, dim, num_heads, window_size, qkv_bias=True): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = float(head_dim)**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, size): - - H, W = size - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C) - - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - x = window_partition(x, self.window_size) - x = x.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # attn_windows = self.attn(x_windows) - - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = self.softmax(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - - # merge windows - x = x.view(-1, self.window_size, self.window_size, C) - x = window_reverse(x, B, self.window_size, Hp, Wp) - - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - return x, size - - -class SpatialBlock(nn.Module): - - def __init__(self, - dim, - num_heads, - window_size, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - conv_at_attn=True, - conv_at_ffn=True): - super().__init__() - - self.conv1 = PreNorm(None, DepthWiseConv2d( - dim, 3, 1, 1)) if conv_at_attn else None - self.window_attn = PreNorm( - norm_layer(dim), - WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), - ) - self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, - 1)) if conv_at_ffn else None - self.ffn = PreNorm( - norm_layer(dim), - Mlp(in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer), - ) - - def forward(self, x, size): - if self.conv1: - x, size = self.conv1(x, size) - x, size = self.window_attn(x, size) - - if self.conv2: - x, size = self.conv2(x, size) - x, size = self.ffn(x, size) - return x, size - - -class DaViT(nn.Module): - - def __init__( - self, - in_chans=3, - num_classes=1000, - depths=(1, 1, 3, 1), - patch_size=(7, 2, 2, 2), - patch_stride=(4, 2, 2, 2), - patch_padding=(3, 0, 0, 0), - patch_prenorm=(False, False, False, False), - embed_dims=(64, 128, 192, 256), - num_heads=(3, 6, 12, 24), - num_groups=(3, 6, 12, 24), - window_size=7, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - enable_checkpoint=False, - conv_at_attn=True, - conv_at_ffn=True, - ): - super().__init__() - - self.num_classes = num_classes - self.embed_dims = embed_dims - self.num_heads = num_heads - self.num_groups = num_groups - self.num_stages = len(self.embed_dims) - self.enable_checkpoint = enable_checkpoint - assert self.num_stages == len(self.num_heads) == len(self.num_groups) - - num_stages = len(embed_dims) - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, - sum(depths) * 2) - ] - - depth_offset = 0 - convs = [] - blocks = [] - for i in range(num_stages): - conv_embed = ConvEmbed( - patch_size=patch_size[i], - stride=patch_stride[i], - padding=patch_padding[i], - in_chans=in_chans if i == 0 else self.embed_dims[i - 1], - embed_dim=self.embed_dims[i], - norm_layer=norm_layer, - pre_norm=patch_prenorm[i]) - convs.append(conv_embed) - - block = MySequential(*[ - MySequential( - OrderedDict([('spatial_block', - SpatialBlock( - embed_dims[i], - num_heads[i], - window_size, - drop_path_rate=dpr[depth_offset + j * 2], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - )), - ('channel_block', - ChannelBlock( - embed_dims[i], - num_groups[i], - drop_path_rate=dpr[depth_offset + j * 2 + - 1], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - ))])) for j in range(depths[i]) - ]) - blocks.append(block) - depth_offset += depths[i] * 2 - - self.convs = nn.ModuleList(convs) - self.blocks = nn.ModuleList(blocks) - - self.avgpool = nn.AdaptiveAvgPool1d(1) - - @property - def dim_out(self): - return self.embed_dims[-1] - - def forward_features_unpool(self, x): - """ - forward until avg pooling - Args: - x (_type_): input image tensor - """ - input_size = (x.size(2), x.size(3)) - for conv, block in zip(self.convs, self.blocks): - x, input_size = conv(x, input_size) - x, input_size = block(x, input_size) - return x - - def forward_features(self, x): - x = self.forward_features_unpool(x) - - # (batch_size, num_tokens, token_dim) - x = self.avgpool(x.transpose(1, 2)) - # (batch_size, 1, num_tokens) - x = torch.flatten(x, 1) - x = self.norms(x) - - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - @classmethod - def from_config(cls, config): - return cls( - depths=config.depths, - embed_dims=config.dim_embed, - num_heads=config.num_heads, - num_groups=config.num_groups, - patch_size=config.patch_size, - patch_stride=config.patch_stride, - patch_padding=config.patch_padding, - patch_prenorm=config.patch_prenorm, - drop_path_rate=config.drop_path_rate, - window_size=config.window_size, - ) - - -# Language backbone and processor implementation class Florence2LanguageModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -624,14 +73,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) - self.encoder = BartEncoder(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = BartDecoder(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") + self.encoder = BartEncoder( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + self.decoder = BartDecoder( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder", + ) if self.config.tie_word_embeddings: self.encoder.embed_tokens.weight = self.shared.weight @@ -664,18 +117,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.config = config - self.model = Florence2LanguageModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 + self.model = Florence2LanguageModel( + vllm_config=vllm_config, prefix=f"{prefix}.model" + ) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.vocab_size = config.vocab_size - self.lm_head = BartParallelLMHead(self.vocab_size, - config.d_model, - embed_scale=embed_scale) + self.lm_head = BartParallelLMHead( + self.vocab_size, config.d_model, embed_scale=embed_scale + ) - self.logits_processor = LogitsProcessor(self.vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) if self.config.tie_word_embeddings: self.lm_head.tie_weights(self.model.shared) @@ -689,10 +141,12 @@ def forward( # num_encoder_outputs: int | None = None, **kwargs, ) -> torch.Tensor: - return self.model(input_ids, - positions, - inputs_embeds=inputs_embeds, - encoder_outputs=encoder_outputs) + return self.model( + input_ids, + positions, + inputs_embeds=inputs_embeds, + encoder_outputs=encoder_outputs, + ) def get_encoder_outputs( self, @@ -719,8 +173,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("encoder_attn.kv_proj", "encoder_attn.k_proj", "k"), @@ -733,7 +186,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -747,8 +200,7 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings and "embed_tokens" in name: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -756,22 +208,21 @@ def load_weights(self, weights: Iterable[tuple[str, class Florence2ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): + def get_hf_config(self) -> Florence2Config: return self.ctx.get_hf_config() - def get_hf_processor(self): + def get_hf_processor(self) -> Florence2Processor: return self.ctx.get_hf_processor() def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} def get_num_image_tokens(self) -> int: - processor_config = self.ctx.get_hf_image_processor_config() - return processor_config["image_seq_length"] + processor = self.get_hf_processor() + return processor.num_image_tokens -class Florence2DummyInputsBuilder( - BaseDummyInputsBuilder[Florence2ProcessingInfo]): +class Florence2DummyInputsBuilder(BaseDummyInputsBuilder[Florence2ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -784,18 +235,31 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width = target_height = self.info.get_hf_config().projection_dim + target_width = target_height = ( + self.info.get_hf_config().vision_config.projection_dim + ) return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + "image": self._get_dummy_images( + width=target_width, height=target_height, num_images=num_images + ) } -class Florence2MultiModalProcessor( - EncDecMultiModalProcessor[Florence2ProcessingInfo]): +class Florence2MultiModalProcessor(EncDecMultiModalProcessor[Florence2ProcessingInfo]): + + def __init__(self, info, dummy_inputs, *, cache=None) -> None: + super().__init__(info, dummy_inputs, cache=cache) + # Florence2Config does not expose decoder_start_token_id at the + # top level (it lives in text_config), so vLLM falls back to BOS + # (token 0) and incorrectly prepends it to the decoder prompt. + # Patch the top-level hf_config so vLLM's _prepare_decoder_input_ids + # sees the real value (EOS / token 2) and leaves our prompt intact. + hf_config = info.get_hf_config() + if getattr(hf_config, "decoder_start_token_id", None) is None: + hf_config.decoder_start_token_id = ( + hf_config.text_config.decoder_start_token_id + ) def _hf_processor_applies_updates( self, @@ -804,7 +268,10 @@ def _hf_processor_applies_updates( hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], ) -> bool: - return False + # The Florence2Processor already inserts image_token_id placeholders + # into the input_ids (577 tokens for a 768x768 image), so we tell + # vllm to find those existing placeholders rather than insert new ones. + return bool(mm_items.get_all_counts().get("image", 0)) def create_encoder_prompt( self, @@ -818,7 +285,16 @@ def create_decoder_prompt( prompt: str | list[int], mm_data: MultiModalDataDict, ) -> str | list[int]: - return [self.info.get_hf_config().eos_token_id] + text_config = self.info.get_hf_config().text_config + # Decoder prompt mirrors what transformers does before open-ended + # generation: start with decoder_start_token_id (, token 2), + # then include forced_bos_token_id (, token 0) so that vLLM + # generates from the same position as transformers step 2. + decoder_prompt = [text_config.decoder_start_token_id] + forced_bos = getattr(text_config, "forced_bos_token_id", None) + if forced_bos is not None: + decoder_prompt.append(forced_bos) + return decoder_prompt def _apply_hf_processor_tokens_only( self, @@ -841,14 +317,15 @@ def _call_hf_processor( ) -> BatchFeature: if mm_data: processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs) + prompt, mm_data, mm_kwargs, tok_kwargs + ) else: hf_processor = self.info.get_hf_processor() tokenizer = hf_processor.tokenizer prompt = hf_processor._construct_prompts([prompt])[0] - processed_outputs = tokenizer(prompt, - add_special_tokens=True, - return_tensors="pt") + processed_outputs = tokenizer( + prompt, add_special_tokens=True, return_tensors="pt" + ) processed_outputs["encoder_input_ids"] = processed_outputs["input_ids"] return processed_outputs @@ -868,16 +345,38 @@ def _get_prompt_updates( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - pad_token_id = hf_config.pad_token_id - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [pad_token_id] * num_image_tokens + # The placeholder must cover the FULL encoder input sequence (image + # tokens + text/task tokens) so that vLLM's _get_encoder_seq_lens + # computes the correct value for cross-attention KV cache allocation. + # Using only the image token count (577) would cause cross-attention + # to read only 577/590 K/V pairs, skipping the task-prompt tokens. + # + # With _hf_processor_applies_updates=True, vLLM detects the existing + # token sequence rather than inserting new tokens. By setting the + # insertion to the full encoder_input_ids sequence, the detected + # placeholder range covers all 590 encoder tokens. + insertion: list[int] + image_items = out_mm_kwargs.get("image", []) + if image_items: + item_data = image_items[0].get_data() + enc_ids = item_data.get("encoder_input_ids") + if enc_ids is not None: + insertion = enc_ids.tolist() + else: + # Cache hit: encoder_input_ids not available; fall back. + hf_config = self.info.get_hf_config() + insertion = [ + hf_config.image_token_id + ] * self.info.get_num_image_tokens() + else: + hf_config = self.info.get_hf_config() + insertion = [hf_config.image_token_id] * self.info.get_num_image_tokens() return [ PromptInsertion( modality="image", target=PromptIndexTargets.start(), - insertion=image_tokens, + insertion=insertion, ) ] @@ -885,8 +384,17 @@ def _get_prompt_updates( @MULTIMODAL_REGISTRY.register_processor( Florence2MultiModalProcessor, info=Florence2ProcessingInfo, - dummy_inputs=Florence2DummyInputsBuilder) + dummy_inputs=Florence2DummyInputsBuilder, +) class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.language_model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: @@ -901,63 +409,39 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): processor_config = vllm_config.model_config.hf_image_processor_config self.config = config - self.vision_config = config.vision_config self.processor_config = processor_config - assert config.vision_config.model_type == 'davit', ( - 'only DaViT is supported for now') - self.vision_tower = DaViT.from_config(config=config.vision_config) - self._build_image_projection_layers(config) + assert config.vision_config.model_type == "florence_vision", ( + f"only Florence Vision is supported for now. " + f"Received model type: {config.vision_config.model_type}" + ) + self.vision_tower = Florence2VisionBackbone(config.vision_config) + self.multi_modal_projector = Florence2MultiModalProjector(config) self.language_model = Florence2LanguageForConditionalGeneration( vllm_config=vllm_config.with_hf_config(config.text_config), prefix=f"{prefix}.language_model", ) - self.pad_token_id = config.pad_token_id - - def _build_image_projection_layers(self, config: PretrainedConfig): - image_dim_out = config.vision_config.dim_embed[-1] - dim_projection = config.vision_config.projection_dim - self.image_projection = nn.Parameter( - torch.empty(image_dim_out, dim_projection)) - self.image_proj_norm = nn.LayerNorm(dim_projection) - image_pos_embed_config = config.vision_config.image_pos_embed - if image_pos_embed_config['type'] == 'learned_abs_2d': - self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( - embedding_dim=image_dim_out, - num_pos=image_pos_embed_config['max_pos_embeddings']) - else: - raise NotImplementedError("Florence2 only supports learned_abs_2d " - "as image position embedding.") - - self.image_feature_source = config.vision_config.image_feature_source - - # temporal embedding - visual_temporal_embedding_config = ( - self.vision_config.visual_temporal_embedding) - if visual_temporal_embedding_config['type'] == 'COSINE': - self.visual_temporal_embed = PositionalEmbeddingCosine1D( - embed_dim=image_dim_out, - max_seq_len=visual_temporal_embedding_config[ - 'max_temporal_embeddings']) - else: - raise NotImplementedError( - 'Florence2 only supports COSINE as temporal embedding.') + self.pad_token_id = config.text_config.pad_token_id def _validate_pixel_values( self, data: torch.Tensor | list[torch.Tensor] ) -> torch.Tensor | list[torch.Tensor]: + # The image processor config may use "size" or "crop_size"; fall back + # to reading the actual tensor shape if neither key is available. + cfg = self.processor_config + size = cfg.get("size") or cfg.get("crop_size") + if size is None: + return data - size = self.processor_config["size"] h, w = size["height"], size["width"] expected_dims = (3, h, w) def _validate_shape(d: torch.Tensor): actual_dims = tuple(d.shape) - if actual_dims != expected_dims: - expected_expr = tuple(*map(str, expected_dims)) raise ValueError( "The expected shape of pixel values per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") + f"is {expected_dims}. You supplied {actual_dims}." + ) for d in data: _validate_shape(d) @@ -965,112 +449,43 @@ def _validate_shape(d: torch.Tensor): return data def _parse_and_validate_image_input(self, **kwargs: object): - pixel_values: list[list[torch.Tensor]] | list[torch.Tensor] | torch.Tensor | None = kwargs.pop( - "pixel_values", None) - image_embeds: list[list[torch.Tensor]] | list[torch.Tensor] | torch.Tensor | None = kwargs.pop( - "image_embeds", None) + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - if pixel_values is not None and image_embeds is not None: - raise ValueError( - "Both pixel values and image embeds are provided.") - + raise ValueError("Both pixel values and image embeds are provided.") if pixel_values is not None: return Florence2ImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), ) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") + raise NotImplementedError("image_embeds not supported.") def _parse_and_validate_encoder_input(self, **kwargs: object) -> list[torch.Tensor]: encoder_input_ids = kwargs.get("encoder_input_ids", kwargs.get("input_ids")) - if encoder_input_ids is None: return [] - if not isinstance(encoder_input_ids, (torch.Tensor, list)): raise ValueError( - "Incorrect type of encoder input_ids. " - f"Got type: {type(encoder_input_ids)}" + f"Incorrect type of encoder input_ids. Got type: {type(encoder_input_ids)}" ) - - # Return as a list of tensors (one per item in the batch) if isinstance(encoder_input_ids, list): - # Already a list - ensure each item is valid - result = [] - for item in encoder_input_ids: - if isinstance(item, torch.Tensor): - if item.dim() == 0: - item = item.unsqueeze(0) - result.append(item) - else: - result.append(item) - return result - else: - # [1xD]xN times - return encoder_input_ids.unsqueeze(1).unbind(dim=0) + return [ + item.unsqueeze(0) if item.dim() == 0 else item + for item in encoder_input_ids + ] + return encoder_input_ids.unsqueeze(1).unbind(dim=0) def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: - dtype = next(self.vision_tower.parameters()).dtype - pixel_values = pixel_values.to(dtype) - - batch_size, T = pixel_values.size(0), 1 - x = self.vision_tower.forward_features_unpool(pixel_values) - if self.image_pos_embed is not None: - x = x.view(batch_size * T, -1, x.shape[-1]) - num_tokens = x.shape[-2] - h, w = int(num_tokens**0.5), int(num_tokens**0.5) - assert h * w == num_tokens, ( - 'only support square feature maps for now') - x = x.view(batch_size * T, h, w, x.shape[-1]) - pos_embed = self.image_pos_embed(x) - x = x + pos_embed - x = x.view(batch_size, T * h * w, x.shape[-1]) - - if self.visual_temporal_embed is not None: - visual_temporal_embed = self.visual_temporal_embed( - x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) - x = x.view(batch_size, T, -1, - x.shape[-1]) + visual_temporal_embed.view( - 1, T, 1, x.shape[-1]) - - x_feat_dict = {} - - spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) - x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x - - temporal_avg_pool_x = x.view(batch_size, T, -1, - x.shape[-1]).mean(dim=1) - x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x - - x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] - x_feat_dict['last_frame'] = x - - new_x = [] - for _image_feature_source in self.image_feature_source: - if _image_feature_source not in x_feat_dict: - raise ValueError('invalid image feature source: {}'.format( - _image_feature_source)) - new_x.append(x_feat_dict[_image_feature_source]) - - x = torch.cat(new_x, dim=1) - - x = x @ self.image_projection - x = self.image_proj_norm(x) - - return x + pixel_values = pixel_values.to(next(self.vision_tower.parameters()).dtype) + return self.multi_modal_projector(self.vision_tower(pixel_values)) def _process_image_input( - self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: - assert image_input["type"] == "pixel_values" - pixel_values = image_input["data"] - return self._encode_image(pixel_values) + self, image_input: Florence2ImagePixelInputs + ) -> torch.Tensor: + return self._encode_image(image_input["data"]) def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -1085,14 +500,10 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: if not encoder_input_ids_list: raise ValueError( - "encoder_input_ids_list is empty - this should not happen. " - "Check that multimodal data is being passed correctly." + "encoder_input_ids_list is empty - check multimodal data is being passed correctly." ) - # Process each encoder input separately and return a list of outputs - # NOTE (NickLucche): Basic encoder batching optimization: BART input sequences - # can have different lengths. Due to computational load of encoder being very - # low here, we batch all sequences to run a single forward by max_seq padding. + # Batch encoder inputs (pad to max length if needed) and run a single forward pass. lengths = [t.numel() for t in encoder_input_ids_list] max_len = max(lengths) if lengths else 0 assert max_len > 0, "Empty encoder_input_ids encountered." @@ -1100,45 +511,47 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: if len(encoder_input_ids_list) == 1: batch_encoder_input_ids = encoder_input_ids_list[0] elif same_len: - # [1xD]xN =>NxD batch_encoder_input_ids = torch.cat(encoder_input_ids_list, dim=0) else: batch_encoder_input_ids = torch.full( (len(encoder_input_ids_list), max_len), - fill_value=self._pad_id, + fill_value=self.pad_token_id, dtype=encoder_input_ids_list[0].dtype, device=encoder_input_ids_list[0].device, ) for i, t in enumerate(encoder_input_ids_list): batch_encoder_input_ids[i, : t.numel()] = t.squeeze() - # Create (B, T) positions: 0..T-1 for each item. - # batch_encoder_positions = torch.arange( - # max_len, - # dtype=torch.long, - # device=batch_encoder_input_ids.device, - # ).unsqueeze(0).expand(batch_encoder_input_ids.size(0), -1) - - inputs_embeds = self.language_model.model.encoder.embed_tokens(batch_encoder_input_ids) - inputs_embeds = torch.cat([vision_embeddings, inputs_embeds], dim=-2) - batch_encoder_positions = torch.arange( - inputs_embeds.size(1), - dtype=torch.long, - device=inputs_embeds.device, - ).unsqueeze(0).expand(inputs_embeds.size(0), -1) - - # Run encoder once on the batch. + inputs_embeds = self.language_model.model.encoder.embed_tokens( + batch_encoder_input_ids + ) + + # Replace the leading image_token_id placeholders with vision features. + if ( + isinstance(vision_embeddings, torch.Tensor) + and vision_embeddings.numel() > 0 + ): + num_vision = vision_embeddings.size(1) + inputs_embeds = inputs_embeds.clone() + inputs_embeds[:, :num_vision, :] = vision_embeddings + batch_encoder_positions = ( + torch.arange( + inputs_embeds.size(1), + dtype=torch.long, + device=inputs_embeds.device, + ) + .unsqueeze(0) + .expand(inputs_embeds.size(0), -1) + ) + + # Run encoder once on the batch, then split back per item. batch_encoder_output = self.language_model.model.encoder( input_ids=batch_encoder_input_ids, positions=batch_encoder_positions, inputs_embeds=inputs_embeds, ) - # Split back into list[(T, H)] to match expected downstream format. - # If we had to pad, slice back to the original lengths per item. encoder_outputs: list[torch.Tensor] = batch_encoder_output.unbind(dim=0) if not same_len: - encoder_outputs = [ - out[:l] for out, l in zip(encoder_outputs, lengths) - ] + encoder_outputs = [out[:l] for out, l in zip(encoder_outputs, lengths)] return encoder_outputs def forward( @@ -1148,30 +561,18 @@ def forward( intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, encoder_outputs: torch.Tensor | None = None, - # num_encoder_outputs: int | None = None, **kwargs, ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ if encoder_outputs is not None: # Assume same shape for all encoder outputs encoder_outputs = torch.cat(encoder_outputs, dim=0) - hidden_states = self.language_model(input_ids, - positions, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model( + input_ids, + positions, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + ) return hidden_states def compute_logits( @@ -1180,7 +581,11 @@ def compute_logits( ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # pos_idx_to_embed is a register_buffer in the transformers implementation + # (deterministically computed from config), so it has no matching parameter. + loader = AutoWeightsLoader( + self, + ignore_unexpected_suffixes=["visual_temporal_embed.pos_idx_to_embed"], + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)