Skip to content

Commit 911cc2d

Browse files
committed
Add Mistral3 multimodal support with Pixtral vision encoder
This adds support for Mistral3 multimodal models (vision + text):

- `Bumblebee.Vision.Pixtral`: Pixtral vision encoder with RoPE support
- `Bumblebee.Text.Mistral3`: Mistral3 text decoder with interleaved attention
- `Bumblebee.Multimodal.Mistral3`: Vision-language model combining Pixtral and Mistral3 with a multimodal projector for image-conditioned generation
- Ministral/Ministral3 variant support with interleaved attention
- Devstral 2 (Ministral3) model support

Supported architectures:

- PixtralVisionModel
- Mistral3Model, Mistral3ForCausalLM, Mistral3ForSequenceClassification
- Mistral3ForConditionalGeneration (multimodal)
- Ministral3ForCausalLM
1 parent 8365426 commit 911cc2d

10 files changed

Lines changed: 1785 additions & 8 deletions

File tree

lib/bumblebee.ex

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,13 @@ defmodule Bumblebee do
170170
"MistralModel" => {Bumblebee.Text.Mistral, :base},
171171
"MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
172172
"MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
173+
"Mistral3Model" => {Bumblebee.Text.Mistral3, :base},
174+
"Mistral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling},
175+
"Mistral3ForSequenceClassification" =>
176+
{Bumblebee.Text.Mistral3, :for_sequence_classification},
177+
"Ministral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling},
178+
"Mistral3ForConditionalGeneration" =>
179+
{Bumblebee.Multimodal.Mistral3, :for_conditional_generation},
173180
"PhiModel" => {Bumblebee.Text.Phi, :base},
174181
"PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
175182
"PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
@@ -198,6 +205,7 @@ defmodule Bumblebee do
198205
"T5Model" => {Bumblebee.Text.T5, :base},
199206
"T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
200207
"T5EncoderModel" => {Bumblebee.Text.T5, :encoder},
208+
"PixtralVisionModel" => {Bumblebee.Vision.Pixtral, :base},
201209
"ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification},
202210
"ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling},
203211
"ViTModel" => {Bumblebee.Vision.Vit, :base},
@@ -255,6 +263,8 @@ defmodule Bumblebee do
255263
"layoutlm" => :layout_lm,
256264
"llama" => :llama,
257265
"mistral" => :llama,
266+
"mistral3" => :llama,
267+
"ministral3" => :llama,
258268
"mbart" => :mbart,
259269
"phi" => :code_gen,
260270
"phi3" => :llama,

lib/bumblebee/layers/transformer.ex

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ defmodule Bumblebee.Layers.Transformer do
2525
- a keyword list (applied to all blocks)
2626
- a function that takes the block index and returns the configuration
2727
28+
* `:attention_window_size` - window size for sliding attention. Can be:
29+
- a tuple `{left_size, right_size}` (applied to all blocks)
30+
- a function that takes the block index and returns the configuration
31+
(useful for interleaved attention patterns)
32+
- `nil` for global attention
33+
2834
* `:name` - the prefix for layer names
2935
3036
For all other options (including required options) see `block/2`.
@@ -52,7 +58,6 @@ defmodule Bumblebee.Layers.Transformer do
5258
:output_use_bias,
5359
:layer_norm,
5460
:block_type,
55-
:attention_window_size,
5661
:scale_attention_weights
5762
]
5863

@@ -64,6 +69,7 @@ defmodule Bumblebee.Layers.Transformer do
6469
:name,
6570
:num_blocks,
6671
:rotary_embedding,
72+
:attention_window_size,
6773
attention_mask: Layers.none(),
6874
attention_head_mask: Layers.none(),
6975
attention_relative_bias: nil,
@@ -85,6 +91,7 @@ defmodule Bumblebee.Layers.Transformer do
8591
cross_attention_head_mask = opts[:cross_attention_head_mask]
8692
cache = opts[:cache]
8793
rotary_embedding = opts[:rotary_embedding]
94+
attention_window_size = opts[:attention_window_size]
8895

8996
block_opts = Keyword.take(opts, block_opts_keys)
9097

@@ -121,6 +128,13 @@ defmodule Bumblebee.Layers.Transformer do
121128
config when is_list(config) -> config
122129
end
123130

131+
block_attention_window_size =
132+
case attention_window_size do
133+
nil -> nil
134+
fun when is_function(fun, 1) -> fun.(idx)
135+
config -> config
136+
end
137+
124138
{hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
125139
block(
126140
state.hidden_state,
@@ -134,6 +148,7 @@ defmodule Bumblebee.Layers.Transformer do
134148
block_cache: block_cache,
135149
offset: offset,
136150
rotary_embedding: block_rotary_embedding,
151+
attention_window_size: block_attention_window_size,
137152
name: join(name, idx)
138153
] ++ block_opts
139154
)

0 commit comments

Comments (0)