From ba6d493eaf5f1dfddccd2d691e8252f31404bc91 Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:42:02 +0300 Subject: [PATCH 1/7] Add Gemma4 model implementation and update model mappings --- lib/bumblebee.ex | 3 + lib/bumblebee/text/gemma4_text.ex | 621 ++++++++++++++++++++++++++++++ 2 files changed, 624 insertions(+) create mode 100644 lib/bumblebee/text/gemma4_text.ex diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 6237dc61..17c46fe2 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -142,6 +142,8 @@ defmodule Bumblebee do "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling}, "Gemma3TextForSequenceClassification" => {Bumblebee.Text.Gemma3Text, :for_sequence_classification}, + "Gemma4ForConditionalGeneration" => + {Bumblebee.Text.Gemma4Text, :for_causal_language_modeling}, "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification}, "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification}, "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling}, @@ -273,6 +275,7 @@ defmodule Bumblebee do "clip" => :clip, "gemma" => :gemma, "gemma3_text" => :gemma, + "gemma4" => :gemma, "gpt_neox" => :gpt_neo_x, "gpt2" => :gpt2, "gpt_bigcode" => :gpt2, diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex new file mode 100644 index 00000000..72a3bbd5 --- /dev/null +++ b/lib/bumblebee/text/gemma4_text.ex @@ -0,0 +1,621 @@ +defmodule Bumblebee.Text.Gemma4Text do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 262_144, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 131_072, + doc: """ + the maximum sequence length that this model can process + """ + ], + hidden_size: [ + default: 2304, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 9216, + doc: "the dimensionality of intermediate layers" + ], + attention_head_size: [ + default: 256, + doc: + "the size of the key, value, and query projection per attention head for sliding attention layers" + ], + global_attention_head_size: [ + default: 512, + doc: + "the size of the key, value, and query projection per attention head for global (full) attention layers" + ], + num_blocks: [ + default: 30, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 8, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: 4, + doc: "the number of key value heads for each attention layer in the model" + ], + num_global_key_value_heads: [ + default: nil, + doc: """ + the number of key value heads for global (full) attention layers. + If nil, defaults to num_key_value_heads. + """ + ], + activation: [ + default: :gelu_approx_tanh, + doc: "the activation function" + ], + rotary_embedding_base: [ + default: 1_000_000, + doc: "base for computing rotary embedding frequency for global attention layers" + ], + rotary_embedding_base_local: [ + default: 10_000, + doc: "base for computing rotary embedding frequency for local (sliding) attention layers" + ], + partial_rotary_factor: [ + default: 1.0, + doc: """ + the fraction of head dimensions to apply rotary embeddings to in global attention layers. + Sliding attention layers always use full rotation (1.0). + Extracted from rope_parameters.full_attention.partial_rotary_factor. + """ + ], + use_attention_bias: [ + default: false, + doc: + "whether or not to use bias in the query, key, value, and output projections in attention layers" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + attention_window_size: [ + default: 512, + doc: + "window size for both sides of the sliding attention window (used for `:sliding_attention` layers)" + ], + layer_types: [ + default: nil, + doc: """ + a list of layer types for each layer, where each element is either `:sliding_attention` + (local attention with sliding window) or `:full_attention` (global attention) + """ + ], + tie_word_embeddings: [ + default: true, + doc: "whether to tie input and output embedding weights" + ], + final_logit_softcapping: [ + default: nil, + doc: """ + if set, logits are capped using `tanh` to this value before the final softmax. + This prevents extreme logit values from dominating the output distribution. + Logits are scaled by tanh(logit / cap) * cap. + """ + ], + hidden_size_per_layer_input: [ + default: 256, + doc: """ + the dimensionality of the per-layer input embeddings (PLE). Each transformer layer + gets its own small embedding that is added to the main hidden state. + """ + ], + vocab_size_per_layer_input: [ + default: 262_144, + doc: "the vocabulary size for per-layer input embeddings" + ], + num_kv_shared_layers: [ + default: 0, + doc: """ + the number of consecutive decoder layers that share the same key-value projections. + A value of 0 means no sharing (each layer has independent KV projections). + """ + ], + use_double_wide_mlp: [ + default: false, + doc: """ + whether to use a double-width MLP with fused gate and up projections. + When true, the gate and up projections are doubled in size. + """ + ] + ] ++ + Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + Gemma 4 model family (text backbone). + + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + shift: 1.0, + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + |> Axon.nx(fn x -> + normalization_factor = + spec.hidden_size + |> Nx.tensor(type: Nx.type(x)) + |> Nx.sqrt() + + Nx.multiply(x, normalization_factor) + end) + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + + attention_window_size = fn idx -> + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> nil + :sliding_attention -> {spec.attention_window_size, spec.attention_window_size} + end + end + + rotary_embedding = fn idx -> + base = + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.rotary_embedding_base + :sliding_attention -> spec.rotary_embedding_base_local + end + + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: base + ] + end + + attention_scale = :math.pow(spec.attention_head_size, -0.5) + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + attention_scale: attention_scale, + kernel_initializer: kernel_initializer(spec), + layer_norm: + &Layers.rms_norm(&1, + shift: 1.0, + name: &2, + epsilon: spec.layer_norm_epsilon, + upcast: :all + ), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: &gemma4_block_impl(&1, &2, &3, spec), + causal: true, + rotary_embedding: rotary_embedding, + attention_window_size: attention_window_size, + query_norm: query_norm, + key_norm: key_norm, + query_use_bias: spec.use_attention_bias, + key_use_bias: spec.use_attention_bias, + value_use_bias: spec.use_attention_bias, + output_use_bias: spec.use_attention_bias, + name: join(name, "blocks") + ) + end + + # Custom block implementation for Gemma 4's normalization structure: + # - Post-attention norm BEFORE residual add + # - Pre/post FFN norms + defp gemma4_block_impl(hidden_state, steps, name, spec) do + shortcut = hidden_state + + {hidden_state, attention_info} = + hidden_state + |> steps.self_attention_norm.() + |> steps.self_attention.() + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_attention_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + shortcut = hidden_state + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "pre_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = steps.ffn.(hidden_state) + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + # Handle cross-attention (required by block interface but not used by Gemma 4) + {_hidden_state, cross_attention_info} = + steps.cross_attention_maybe.(hidden_state, fn _ -> + raise "cross attention not supported" + end) + + {hidden_state, attention_info, cross_attention_info} + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + # Generate layer_types fallback: every 5th layer uses full attention + defp generate_layer_types(num_blocks) do + Enum.map(0..(num_blocks - 1), fn i -> + if rem(i + 1, 5) == 0 do + :full_attention + else + :sliding_attention + end + end) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + data = data["text_config"] || data + rope_params = data["rope_parameters"] || %{} + full_attention_rope = rope_params["full_attention"] || %{} + sliding_attention_rope = rope_params["sliding_attention"] || %{} + + data = + data + |> Map.put_new("rope_theta", full_attention_rope["rope_theta"] || 1_000_000) + |> Map.put_new("rope_local_base_freq", sliding_attention_rope["rope_theta"] || 10_000) + |> Map.put_new( + "partial_rotary_factor", + full_attention_rope["partial_rotary_factor"] || 1.0 + ) + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + num_global_key_value_heads: {"num_global_key_value_heads", optional(number())}, + attention_head_size: {"head_dim", number()}, + global_attention_head_size: {"global_head_dim", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_activation", activation()}, + use_attention_bias: {"attention_bias", boolean()}, + rotary_embedding_base: {"rope_theta", number()}, + rotary_embedding_base_local: {"rope_local_base_freq", number()}, + partial_rotary_factor: {"partial_rotary_factor", number()}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + attention_window_size: {"sliding_window", optional(number())}, + layer_types: + {"layer_types", + list( + mapping(%{ + "sliding_attention" => :sliding_attention, + "full_attention" => :full_attention + }) + )}, + tie_word_embeddings: {"tie_word_embeddings", boolean()}, + final_logit_softcapping: {"final_logit_softcapping", optional(number())}, + hidden_size_per_layer_input: {"hidden_size_per_layer_input", number()}, + vocab_size_per_layer_input: {"vocab_size_per_layer_input", number()}, + num_kv_shared_layers: {"num_kv_shared_layers", number()}, + use_double_wide_mlp: {"use_double_wide_mlp", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + %{ + "embedder.token_embedding" => "model.language_model.embed_tokens", + # PLE (Per-Layer Embeddings) global weights + "embedder.token_embedding_per_layer" => "model.language_model.embed_tokens_per_layer", + "per_layer_model_projection" => "model.language_model.per_layer_model_projection", + "per_layer_projection_norm" => "model.language_model.per_layer_projection_norm", + # PLE per-layer weights + "decoder.blocks.{n}.layer_scalar" => "model.language_model.layers.{n}.layer_scalar", + "decoder.blocks.{n}.per_layer_input_gate" => + "model.language_model.layers.{n}.per_layer_input_gate", + "decoder.blocks.{n}.per_layer_projection" => + "model.language_model.layers.{n}.per_layer_projection", + "decoder.blocks.{n}.post_per_layer_input_norm" => + "model.language_model.layers.{n}.post_per_layer_input_norm", + # Attention projections + "decoder.blocks.{n}.self_attention.query" => + "model.language_model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => + "model.language_model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => + "model.language_model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => + "model.language_model.layers.{n}.self_attn.o_proj", + # QK-norm + "decoder.blocks.{n}.self_attention.query_norm" => + "model.language_model.layers.{n}.self_attn.q_norm", + "decoder.blocks.{n}.self_attention.key_norm" => + "model.language_model.layers.{n}.self_attn.k_norm", + # Layer norms + "decoder.blocks.{n}.self_attention_norm" => + "model.language_model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.post_attention_norm" => + "model.language_model.layers.{n}.post_attention_layernorm", + # FFN layer norms + "decoder.blocks.{n}.pre_ffn_norm" => + "model.language_model.layers.{n}.pre_feedforward_layernorm", + "decoder.blocks.{n}.post_ffn_norm" => + "model.language_model.layers.{n}.post_feedforward_layernorm", + # FFN projections + "decoder.blocks.{n}.ffn.gate" => "model.language_model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.language_model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.language_model.layers.{n}.mlp.down_proj", + # Output + "output_norm" => "model.language_model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, + do: "model.language_model.embed_tokens", + else: "lm_head" + ), + "sequence_classification_head.output" => "score" + } + end + end +end From 2c21e11d6fd83cbd8e2209791e6a949660ac717e Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Sat, 13 Jun 2026 19:14:20 +0300 Subject: [PATCH 2/7] Gemma4Text to enhance cache initialization and attention head size handling --- lib/bumblebee/layers/transformer.ex | 11 +++- lib/bumblebee/text/gemma4_text.ex | 79 +++++++++++++++++++++++------ 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 188b0ffe..e7b18496 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -52,7 +52,6 @@ defmodule Bumblebee.Layers.Transformer do :hidden_size, :ffn, :kernel_initializer, - :attention_head_size, :dropout_rate, :attention_dropout_rate, :query_use_bias, @@ -75,6 +74,7 @@ defmodule Bumblebee.Layers.Transformer do :num_blocks, :rotary_embedding, :attention_window_size, + :attention_head_size, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: nil, @@ -97,6 +97,7 @@ defmodule Bumblebee.Layers.Transformer do cache = opts[:cache] rotary_embedding = opts[:rotary_embedding] attention_window_size = opts[:attention_window_size] + attention_head_size = opts[:attention_head_size] block_opts = Keyword.take(opts, block_opts_keys) @@ -142,12 +143,20 @@ defmodule Bumblebee.Layers.Transformer do size -> size end + block_attention_head_size = + case attention_head_size do + nil -> nil + fun when is_function(fun, 1) -> fun.(idx) + size -> size + end + {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} = block( state.hidden_state, [ attention_mask: attention_mask, attention_head_mask: block_attention_head_mask, + attention_head_size: block_attention_head_size, attention_relative_bias: attention_relative_bias, cross_hidden_state: cross_hidden_state, cross_attention_mask: cross_attention_mask, diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex index 72a3bbd5..7d94d54e 100644 --- a/lib/bumblebee/text/gemma4_text.ex +++ b/lib/bumblebee/text/gemma4_text.ex @@ -185,12 +185,27 @@ defmodule Bumblebee.Text.Gemma4Text do @impl true def init_cache(spec, batch_size, max_length, _inputs) do - Layers.Decoder.init_cache(batch_size, max_length, - hidden_size: spec.hidden_size, - attention_head_size: spec.attention_head_size, - decoder_num_attention_heads: spec.num_attention_heads, - decoder_num_blocks: spec.num_blocks - ) + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + + blocks = + Enum.map(0..(spec.num_blocks - 1), fn idx -> + head_size = + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.global_attention_head_size + :sliding_attention -> spec.attention_head_size + end + + shape = {batch_size, max_length, spec.num_attention_heads, head_size} + zeros = Nx.broadcast(0.0, shape) + self_attention = %{key: zeros, value: zeros} + + %{self_attention: self_attention, cross_attention: %Axon.None{}} + end) + |> List.to_tuple() + + offset = Nx.tensor(0) + attention_mask = Nx.broadcast(0, {batch_size, max_length}) + %{blocks: blocks, offset: offset, attention_mask: attention_mask} end @impl true @@ -374,6 +389,7 @@ defmodule Bumblebee.Text.Gemma4Text do end attention_scale = :math.pow(spec.attention_head_size, -0.5) + non_double_wide_count = spec.num_blocks - spec.num_kv_shared_layers Layers.Transformer.blocks(hidden_state, attention_mask: attention_mask, @@ -383,7 +399,12 @@ defmodule Bumblebee.Text.Gemma4Text do num_attention_heads: spec.num_attention_heads, num_key_value_heads: spec.num_key_value_heads, hidden_size: spec.hidden_size, - attention_head_size: spec.attention_head_size, + attention_head_size: fn idx -> + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.global_attention_head_size + :sliding_attention -> spec.attention_head_size + end + end, attention_scale: attention_scale, kernel_initializer: kernel_initializer(spec), layer_norm: @@ -393,11 +414,25 @@ defmodule Bumblebee.Text.Gemma4Text do epsilon: spec.layer_norm_epsilon, upcast: :all ), - ffn: - &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, - name: &2, + ffn: fn hidden_state, ffn_name -> + idx = + ffn_name + |> String.split(".") + |> Enum.at(2) + |> String.to_integer() + + intermediate_size = + if spec.use_double_wide_mlp and idx >= non_double_wide_count do + spec.intermediate_size * 2 + else + spec.intermediate_size + end + + gated_ffn(hidden_state, intermediate_size, spec.hidden_size, + name: ffn_name, activation: spec.activation - ), + ) + end, block_type: &gemma4_block_impl(&1, &2, &3, spec), causal: true, rotary_embedding: rotary_embedding, @@ -484,10 +519,24 @@ defmodule Bumblebee.Text.Gemma4Text do defp language_modeling_head(hidden_state, spec, opts) do name = opts[:name] - Layers.dense_transposed(hidden_state, spec.vocab_size, - kernel_initializer: kernel_initializer(spec), - name: join(name, "output") - ) + logits = + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + + cap = spec.final_logit_softcapping + + if cap do + Axon.nx(logits, fn x -> + x + |> Nx.divide(cap) + |> Nx.tanh() + |> Nx.multiply(cap) + end) + else + logits + end end defp kernel_initializer(spec) do From 41a2d034c176b7d0fa905e1de2df7a4b3bc75f19 Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Sat, 13 Jun 2026 19:20:51 +0300 Subject: [PATCH 3/7] Implement per-layer input handling in Gemma4 model and add ChunkedFileTensor for large tensor support --- lib/bumblebee.ex | 6 +- lib/bumblebee/conversion/pytorch_params.ex | 47 ++++++++- lib/bumblebee/text/gemma4_text.ex | 117 ++++++++++++++++++++- lib/bumblebee/utils/chunked_file_tensor.ex | 71 +++++++++++++ 4 files changed, 235 insertions(+), 6 deletions(-) create mode 100644 lib/bumblebee/utils/chunked_file_tensor.ex diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 17c46fe2..d10a56e1 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -780,11 +780,15 @@ defmodule Bumblebee do end defp params_file_loader_fun(".safetensors", opts) do - opts[:safetensors_reader] || (&Safetensors.read!(&1, lazy: true)) + opts[:safetensors_reader] || (&read_safetensors_chunked/1) end defp params_file_loader_fun(_, _opts), do: &Bumblebee.Conversion.PyTorchLoader.load!/1 + defp read_safetensors_chunked(path) do + Safetensors.read!(path, lazy: true) + end + @doc """ Featurizes `input` with the given featurizer. diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex index c17fef85..0576fe8b 100644 --- a/lib/bumblebee/conversion/pytorch_params.ex +++ b/lib/bumblebee/conversion/pytorch_params.ex @@ -150,7 +150,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do {value, diff} = if all_sources_found? do - source_values = Enum.map(source_values, &Nx.to_tensor/1) + source_values = Enum.map(source_values, &lazy_to_tensor/1) value = builder_fun.(Enum.reverse(source_values)) case verify_param_shape(param_expr, value) do @@ -188,6 +188,51 @@ defmodule Bumblebee.Conversion.PyTorchParams do defp prepend(diff, key, values), do: Map.update!(diff, key, &(values ++ &1)) + # macOS pread(2) returns EINVAL when byte count > INT_MAX (~2 GB). + # For large safetensors tensors, read in 1 GB chunks instead. + @pread_chunk 1_073_741_824 + + defp lazy_to_tensor(%Safetensors.FileTensor{byte_size: size} = ft) + when size > @pread_chunk do + # Force BinaryBackend: the GPU backend (EMLX) cannot allocate tensors + # this large in a single call, and we must also avoid the macOS pread + # INT_MAX limit by reading in chunks. + Nx.with_default_backend(Nx.BinaryBackend, fn -> + File.open!(ft.path, [:read, :raw], fn file -> + binary = pread_chunked(file, ft.byte_offset, ft.byte_size) + Safetensors.Shared.build_tensor(binary, ft.shape, ft.type) + end) + end) + end + + defp lazy_to_tensor(value), do: Nx.to_tensor(value) + + defp pread_chunked(file, offset, size) when size <= @pread_chunk do + {:ok, binary} = :file.pread(file, offset, size) + binary + end + + defp pread_chunked(file, offset, size) do + full = div(size, @pread_chunk) + rest = rem(size, @pread_chunk) + + chunks = + for i <- 0..(full - 1) do + {:ok, chunk} = :file.pread(file, offset + i * @pread_chunk, @pread_chunk) + chunk + end + + chunks = + if rest > 0 do + {:ok, tail} = :file.pread(file, offset + full * @pread_chunk, rest) + chunks ++ [tail] + else + chunks + end + + IO.iodata_to_binary(chunks) + end + defp infer_prefixes(layers, pytorch_state, params_mapping) do # Note: target refers to the parameters we are initializing, while # source refers to the state we are loading from diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex index 7d94d54e..45993336 100644 --- a/lib/bumblebee/text/gemma4_text.ex +++ b/lib/bumblebee/text/gemma4_text.ex @@ -305,6 +305,14 @@ defmodule Bumblebee.Text.Gemma4Text do Layers.default_position_ids(embeddings) end + # PLE: compute per-layer inputs + per_layer_inputs = + if spec.hidden_size_per_layer_input do + compute_per_layer_inputs(inputs["input_ids"], embeddings, spec) + else + nil + end + decoder_outputs = decoder( embeddings, @@ -313,6 +321,7 @@ defmodule Bumblebee.Text.Gemma4Text do inputs["attention_head_mask"], inputs["cache"], spec, + per_layer_inputs: per_layer_inputs, name: "decoder" ) @@ -332,6 +341,60 @@ defmodule Bumblebee.Text.Gemma4Text do } end + defp compute_per_layer_inputs(input_ids, embeddings, spec) do + ple_dim = spec.hidden_size_per_layer_input + num_layers = spec.num_blocks + total_ple_dim = num_layers * ple_dim + + # Token-identity: lookup in per-layer embedding table + token_identity = + Axon.embedding(input_ids, spec.vocab_size_per_layer_input, total_ple_dim, + name: "embedder.token_embedding_per_layer" + ) + |> Axon.nx(fn x -> + # Scale by sqrt(ple_dim) + scale = Nx.tensor(ple_dim, type: Nx.type(x)) |> Nx.sqrt() + x = Nx.multiply(x, scale) + # Reshape from [B, S, num_layers * ple_dim] to [B, S, num_layers, ple_dim] + shape = Nx.shape(x) + batch = elem(shape, 0) + seq = elem(shape, 1) + Nx.reshape(x, {batch, seq, num_layers, ple_dim}) + end) + + # Context-aware: project main embeddings + # Norm is applied before reshape so its weight has shape [num_layers * ple_dim], + # matching HuggingFace's per_layer_projection_norm weight. + context_aware = + Axon.dense(embeddings, total_ple_dim, + name: "per_layer_model_projection", + use_bias: false + ) + |> Layers.rms_norm( + name: "per_layer_projection_norm", + epsilon: spec.layer_norm_epsilon + ) + |> Axon.nx(fn x -> + # Scale by 1/sqrt(hidden_size) + scale = Nx.divide(1.0, Nx.sqrt(Nx.tensor(spec.hidden_size, type: Nx.type(x)))) + x = Nx.multiply(x, scale) + # Reshape to [B, S, num_layers, ple_dim] + shape = Nx.shape(x) + batch = elem(shape, 0) + seq = elem(shape, 1) + Nx.reshape(x, {batch, seq, num_layers, ple_dim}) + end) + + # Combine: (token_identity + context_aware) * (1/sqrt(2)) + Axon.layer( + fn token_id, context, _opts -> + inv_sqrt2 = Nx.tensor(1.0 / :math.sqrt(2), type: Nx.type(token_id)) + Nx.multiply(Nx.add(token_id, context), inv_sqrt2) + end, + [token_identity, context_aware] + ) + end + defp embedder(input_ids, input_embeddings, spec, opts) do name = opts[:name] @@ -361,6 +424,7 @@ defmodule Bumblebee.Text.Gemma4Text do opts ) do name = opts[:name] + per_layer_inputs = opts[:per_layer_inputs] query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) @@ -433,7 +497,7 @@ defmodule Bumblebee.Text.Gemma4Text do activation: spec.activation ) end, - block_type: &gemma4_block_impl(&1, &2, &3, spec), + block_type: &gemma4_block_impl(&1, &2, &3, spec, per_layer_inputs), causal: true, rotary_embedding: rotary_embedding, attention_window_size: attention_window_size, @@ -450,7 +514,54 @@ defmodule Bumblebee.Text.Gemma4Text do # Custom block implementation for Gemma 4's normalization structure: # - Post-attention norm BEFORE residual add # - Pre/post FFN norms - defp gemma4_block_impl(hidden_state, steps, name, spec) do + defp gemma4_block_impl(hidden_state, steps, name, spec, per_layer_inputs) do + # PLE: add per-layer input to hidden state + hidden_state = + if per_layer_inputs do + # Extract layer index from name like "decoder.blocks.5" + idx = + name + |> String.split(".") + |> Enum.at(2) + |> String.to_integer() + + ple_slice = + Axon.nx(per_layer_inputs, fn x -> + x[[.., .., idx, ..]] + end) + + # Gate: sigmoid(gate(ple_slice)) + gate = + Axon.dense(ple_slice, spec.hidden_size, + name: join(name, "per_layer_input_gate"), + use_bias: false + ) + |> Axon.sigmoid() + + # Projection: project ple_slice to hidden_size + projection = + Axon.dense(ple_slice, spec.hidden_size, + name: join(name, "per_layer_projection"), + use_bias: false + ) + + # Gated projection + gated = Axon.multiply(gate, projection) + + # Normalize + gated = + Layers.rms_norm(gated, + name: join(name, "post_per_layer_input_norm"), + epsilon: spec.layer_norm_epsilon + ) + + # Scale by layer_scalar + Axon.add(hidden_state, gated) + else + hidden_state + end + + # Rest of the block stays the same shortcut = hidden_state {hidden_state, attention_info} = @@ -490,7 +601,6 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state = Axon.add(shortcut, hidden_state) - # Handle cross-attention (required by block interface but not used by Gemma 4) {_hidden_state, cross_attention_info} = steps.cross_attention_maybe.(hidden_state, fn _ -> raise "cross attention not supported" @@ -621,7 +731,6 @@ defmodule Bumblebee.Text.Gemma4Text do "per_layer_model_projection" => "model.language_model.per_layer_model_projection", "per_layer_projection_norm" => "model.language_model.per_layer_projection_norm", # PLE per-layer weights - "decoder.blocks.{n}.layer_scalar" => "model.language_model.layers.{n}.layer_scalar", "decoder.blocks.{n}.per_layer_input_gate" => "model.language_model.layers.{n}.per_layer_input_gate", "decoder.blocks.{n}.per_layer_projection" => diff --git a/lib/bumblebee/utils/chunked_file_tensor.ex b/lib/bumblebee/utils/chunked_file_tensor.ex new file mode 100644 index 00000000..fb1e4a77 --- /dev/null +++ b/lib/bumblebee/utils/chunked_file_tensor.ex @@ -0,0 +1,71 @@ +defmodule Bumblebee.Utils.ChunkedFileTensor do + @moduledoc false + + # A lazy tensor that reads from a safetensors shard file in chunks. + # Replaces Safetensors.FileTensor for large tensors to work around macOS's + # pread(2) EINVAL when the byte count exceeds INT_MAX (~2 GB). + defstruct [:path, :byte_offset, :byte_size, :shape, :type] + + # 1 GB chunks — safely below INT_MAX on macOS. + @max_chunk_size 1_073_741_824 + + @doc """ + Wraps a `Safetensors.FileTensor` in a `ChunkedFileTensor`. + """ + def from_file_tensor(%Safetensors.FileTensor{} = ft) do + %__MODULE__{ + path: ft.path, + byte_offset: ft.byte_offset, + byte_size: ft.byte_size, + shape: ft.shape, + type: ft.type + } + end + + @doc """ + Reads `size` bytes at `offset` from an already-open raw file handle, + splitting into ≤1 GB reads to avoid the macOS pread EINVAL limit. + """ + def read_chunked(file, offset, size) when size <= @max_chunk_size do + {:ok, binary} = :file.pread(file, offset, size) + binary + end + + def read_chunked(file, offset, size) do + full_chunks = div(size, @max_chunk_size) + remainder = rem(size, @max_chunk_size) + + chunks = + for i <- 0..(full_chunks - 1) do + {:ok, chunk} = :file.pread(file, offset + i * @max_chunk_size, @max_chunk_size) + chunk + end + + chunks = + if remainder > 0 do + {:ok, tail} = :file.pread(file, offset + full_chunks * @max_chunk_size, remainder) + chunks ++ [tail] + else + chunks + end + + IO.iodata_to_binary(chunks) + end +end + +defimpl Nx.LazyContainer, for: Bumblebee.Utils.ChunkedFileTensor do + alias Bumblebee.Utils.ChunkedFileTensor + + def traverse(lazy, acc, fun) do + template = Nx.template(lazy.shape, lazy.type) + + load = fn -> + File.open!(lazy.path, [:read, :raw], fn file -> + binary = ChunkedFileTensor.read_chunked(file, lazy.byte_offset, lazy.byte_size) + Safetensors.Shared.build_tensor(binary, lazy.shape, lazy.type) + end) + end + + fun.(template, load, acc) + end +end From 1e2a5652e772468c0bb69d2396c68fbe7c5435b1 Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Mon, 15 Jun 2026 12:33:51 +0300 Subject: [PATCH 4/7] Add value_norm option to transformer layers and update Gemma4Text normalization --- lib/bumblebee/layers/transformer.ex | 19 ++- lib/bumblebee/text/gemma4_text.ex | 171 ++++++++++++--------- lib/bumblebee/utils/chunked_file_tensor.ex | 71 --------- 3 files changed, 118 insertions(+), 143 deletions(-) delete mode 100644 lib/bumblebee/utils/chunked_file_tensor.ex diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index e7b18496..075c07d9 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -62,7 +62,8 @@ defmodule Bumblebee.Layers.Transformer do :block_type, :attention_scale, :query_norm, - :key_norm + :key_norm, + :value_norm ] opts = @@ -363,7 +364,8 @@ defmodule Bumblebee.Layers.Transformer do attention_scale: nil, rotary_embedding: nil, query_norm: nil, - key_norm: nil + key_norm: nil, + value_norm: nil ]) name = opts[:name] @@ -395,6 +397,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding = opts[:rotary_embedding] query_norm = opts[:query_norm] key_norm = opts[:key_norm] + value_norm = opts[:value_norm] ffn_fun = case ffn do @@ -455,6 +458,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding: rotary_embedding, query_norm: query_norm, key_norm: key_norm, + value_norm: value_norm, name: join(name, "self_attention") ) @@ -781,7 +785,8 @@ defmodule Bumblebee.Layers.Transformer do output_use_bias: true, rotary_embedding: nil, query_norm: nil, - key_norm: nil + key_norm: nil, + value_norm: nil ]) attention_mask = opts[:attention_mask] @@ -797,6 +802,7 @@ defmodule Bumblebee.Layers.Transformer do causal = opts[:causal] attention_window_size = opts[:attention_window_size] attention_scale = opts[:attention_scale] + value_norm = opts[:value_norm] dropout_rate = opts[:dropout_rate] rotary_embedding = opts[:rotary_embedding] query_norm = opts[:query_norm] @@ -855,6 +861,13 @@ defmodule Bumblebee.Layers.Transformer do key end + value = + if value_norm do + value_norm.(value, join(name, "value_norm")) + else + value + end + {query, key} = case rotary_embedding do opts when is_list(opts) -> diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex index 45993336..6fc07ec6 100644 --- a/lib/bumblebee/text/gemma4_text.ex +++ b/lib/bumblebee/text/gemma4_text.ex @@ -362,28 +362,27 @@ defmodule Bumblebee.Text.Gemma4Text do Nx.reshape(x, {batch, seq, num_layers, ple_dim}) end) - # Context-aware: project main embeddings - # Norm is applied before reshape so its weight has shape [num_layers * ple_dim], - # matching HuggingFace's per_layer_projection_norm weight. + # Context-aware: project main embeddings, reshape, then norm. + # The norm weight has shape [ple_dim] because HuggingFace applies it + # after reshaping to [B, S, num_layers, ple_dim]. context_aware = Axon.dense(embeddings, total_ple_dim, name: "per_layer_model_projection", use_bias: false ) - |> Layers.rms_norm( - name: "per_layer_projection_norm", - epsilon: spec.layer_norm_epsilon - ) |> Axon.nx(fn x -> - # Scale by 1/sqrt(hidden_size) scale = Nx.divide(1.0, Nx.sqrt(Nx.tensor(spec.hidden_size, type: Nx.type(x)))) x = Nx.multiply(x, scale) - # Reshape to [B, S, num_layers, ple_dim] shape = Nx.shape(x) batch = elem(shape, 0) seq = elem(shape, 1) Nx.reshape(x, {batch, seq, num_layers, ple_dim}) end) + |> Layers.rms_norm( + name: "per_layer_projection_norm", + shift: 1.0, + epsilon: spec.layer_norm_epsilon + ) # Combine: (token_identity + context_aware) * (1/sqrt(2)) Axon.layer( @@ -429,6 +428,13 @@ defmodule Bumblebee.Text.Gemma4Text do query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + value_norm = fn value, _name -> + Axon.nx(value, fn x -> + variance = Nx.mean(Nx.multiply(x, x), axes: [-1], keep_axes: true) + Nx.multiply(x, Nx.rsqrt(Nx.add(variance, spec.layer_norm_epsilon))) + end) + end + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) attention_window_size = fn idx -> @@ -439,20 +445,26 @@ defmodule Bumblebee.Text.Gemma4Text do end rotary_embedding = fn idx -> - base = - case Enum.at(layer_types, idx, :sliding_attention) do - :full_attention -> spec.rotary_embedding_base - :sliding_attention -> spec.rotary_embedding_base_local - end - - [ - position_ids: position_ids, - max_positions: spec.max_positions, - base: base - ] + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: :math.pow(spec.rotary_embedding_base, spec.partial_rotary_factor), + percentage: spec.partial_rotary_factor + ] + + :sliding_attention -> + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base_local, + percentage: 1.0 + ] + end end - attention_scale = :math.pow(spec.attention_head_size, -0.5) + attention_scale = 1.0 non_double_wide_count = spec.num_blocks - spec.num_kv_shared_layers Layers.Transformer.blocks(hidden_state, @@ -503,6 +515,7 @@ defmodule Bumblebee.Text.Gemma4Text do attention_window_size: attention_window_size, query_norm: query_norm, key_norm: key_norm, + value_norm: value_norm, query_use_bias: spec.use_attention_bias, key_use_bias: spec.use_attention_bias, value_use_bias: spec.use_attention_bias, @@ -514,54 +527,9 @@ defmodule Bumblebee.Text.Gemma4Text do # Custom block implementation for Gemma 4's normalization structure: # - Post-attention norm BEFORE residual add # - Pre/post FFN norms + # - PLE applied AFTER attention+MLP as a third residual block defp gemma4_block_impl(hidden_state, steps, name, spec, per_layer_inputs) do - # PLE: add per-layer input to hidden state - hidden_state = - if per_layer_inputs do - # Extract layer index from name like "decoder.blocks.5" - idx = - name - |> String.split(".") - |> Enum.at(2) - |> String.to_integer() - - ple_slice = - Axon.nx(per_layer_inputs, fn x -> - x[[.., .., idx, ..]] - end) - - # Gate: sigmoid(gate(ple_slice)) - gate = - Axon.dense(ple_slice, spec.hidden_size, - name: join(name, "per_layer_input_gate"), - use_bias: false - ) - |> Axon.sigmoid() - - # Projection: project ple_slice to hidden_size - projection = - Axon.dense(ple_slice, spec.hidden_size, - name: join(name, "per_layer_projection"), - use_bias: false - ) - - # Gated projection - gated = Axon.multiply(gate, projection) - - # Normalize - gated = - Layers.rms_norm(gated, - name: join(name, "post_per_layer_input_norm"), - epsilon: spec.layer_norm_epsilon - ) - - # Scale by layer_scalar - Axon.add(hidden_state, gated) - else - hidden_state - end - - # Rest of the block stays the same + # 1. Self-attention with pre/post norms shortcut = hidden_state {hidden_state, attention_info} = @@ -579,6 +547,7 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state = Axon.add(shortcut, hidden_state) + # 2. FFN with pre/post norms shortcut = hidden_state hidden_state = @@ -601,6 +570,68 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state = Axon.add(shortcut, hidden_state) + # 3. PLE: gate hidden_state down to PLE dim, multiply with PLE signal, project back up + hidden_state = + if per_layer_inputs do + idx = + name + |> String.split(".") + |> Enum.at(2) + |> String.to_integer() + + ple_slice = + Axon.nx(per_layer_inputs, fn x -> + x[[.., .., idx, ..]] + end) + + shortcut_ple = hidden_state + + # Gate: project hidden_state DOWN to PLE dimension + gated = + Axon.dense(hidden_state, spec.hidden_size_per_layer_input, + name: join(name, "per_layer_input_gate"), + use_bias: false + ) + + # Activation (gelu_approx_tanh, same as FFN) + gated = Layers.activation(gated, spec.activation) + + # Element-wise multiply with PLE signal + gated = Axon.multiply(gated, ple_slice) + + # Project back UP to hidden dimension + gated = + Axon.dense(gated, spec.hidden_size, + name: join(name, "per_layer_projection"), + use_bias: false + ) + + # Normalize + gated = + Layers.rms_norm(gated, + shift: 1.0, + name: join(name, "post_per_layer_input_norm"), + epsilon: spec.layer_norm_epsilon + ) + + Axon.add(shortcut_ple, gated) + else + hidden_state + end + + # 4. Layer scalar: multiply output by per-layer learned scalar + layer_scalar = + Axon.param("layer_scalar", fn _ -> {1} end, initializer: :ones) + + hidden_state = + Axon.layer( + fn hidden_state, scalar, _opts -> + Nx.multiply(hidden_state, scalar) + end, + [hidden_state, layer_scalar], + name: join(name, "layer_scalar") + ) + {_hidden_state, cross_attention_info} = steps.cross_attention_maybe.(hidden_state, fn _ -> raise "cross attention not supported" @@ -730,13 +761,15 @@ defmodule Bumblebee.Text.Gemma4Text do "embedder.token_embedding_per_layer" => "model.language_model.embed_tokens_per_layer", "per_layer_model_projection" => "model.language_model.per_layer_model_projection", "per_layer_projection_norm" => "model.language_model.per_layer_projection_norm", - # PLE per-layer weights "decoder.blocks.{n}.per_layer_input_gate" => "model.language_model.layers.{n}.per_layer_input_gate", "decoder.blocks.{n}.per_layer_projection" => "model.language_model.layers.{n}.per_layer_projection", "decoder.blocks.{n}.post_per_layer_input_norm" => "model.language_model.layers.{n}.post_per_layer_input_norm", + # Per-layer scalar + "decoder.blocks.{n}.layer_scalar" => + "model.language_model.layers.{n}", # Attention projections "decoder.blocks.{n}.self_attention.query" => "model.language_model.layers.{n}.self_attn.q_proj", diff --git a/lib/bumblebee/utils/chunked_file_tensor.ex b/lib/bumblebee/utils/chunked_file_tensor.ex deleted file mode 100644 index fb1e4a77..00000000 --- a/lib/bumblebee/utils/chunked_file_tensor.ex +++ /dev/null @@ -1,71 +0,0 @@ -defmodule Bumblebee.Utils.ChunkedFileTensor do - @moduledoc false - - # A lazy tensor that reads from a safetensors shard file in chunks. - # Replaces Safetensors.FileTensor for large tensors to work around macOS's - # pread(2) EINVAL when the byte count exceeds INT_MAX (~2 GB). - defstruct [:path, :byte_offset, :byte_size, :shape, :type] - - # 1 GB chunks — safely below INT_MAX on macOS. - @max_chunk_size 1_073_741_824 - - @doc """ - Wraps a `Safetensors.FileTensor` in a `ChunkedFileTensor`. - """ - def from_file_tensor(%Safetensors.FileTensor{} = ft) do - %__MODULE__{ - path: ft.path, - byte_offset: ft.byte_offset, - byte_size: ft.byte_size, - shape: ft.shape, - type: ft.type - } - end - - @doc """ - Reads `size` bytes at `offset` from an already-open raw file handle, - splitting into ≤1 GB reads to avoid the macOS pread EINVAL limit. - """ - def read_chunked(file, offset, size) when size <= @max_chunk_size do - {:ok, binary} = :file.pread(file, offset, size) - binary - end - - def read_chunked(file, offset, size) do - full_chunks = div(size, @max_chunk_size) - remainder = rem(size, @max_chunk_size) - - chunks = - for i <- 0..(full_chunks - 1) do - {:ok, chunk} = :file.pread(file, offset + i * @max_chunk_size, @max_chunk_size) - chunk - end - - chunks = - if remainder > 0 do - {:ok, tail} = :file.pread(file, offset + full_chunks * @max_chunk_size, remainder) - chunks ++ [tail] - else - chunks - end - - IO.iodata_to_binary(chunks) - end -end - -defimpl Nx.LazyContainer, for: Bumblebee.Utils.ChunkedFileTensor do - alias Bumblebee.Utils.ChunkedFileTensor - - def traverse(lazy, acc, fun) do - template = Nx.template(lazy.shape, lazy.type) - - load = fn -> - File.open!(lazy.path, [:read, :raw], fn file -> - binary = ChunkedFileTensor.read_chunked(file, lazy.byte_offset, lazy.byte_size) - Safetensors.Shared.build_tensor(binary, lazy.shape, lazy.type) - end) - end - - fun.(template, load, acc) - end -end From 288ed5f9a0256d72732818ae17a878e753e3d1f4 Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:35:10 +0300 Subject: [PATCH 5/7] Add rotary_dim option to rotary embedding layer and update Gemma4Text configuration --- lib/bumblebee/layers.ex | 36 +++++++++++++++++++++++++---- lib/bumblebee/layers/transformer.ex | 1 + lib/bumblebee/text/gemma4_text.ex | 6 +++-- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 27d81990..bfca3944 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1237,7 +1237,7 @@ defmodule Bumblebee.Layers do Adds a rotary embedding layer to the network. """ def rotary_embedding(query, key, position_ids, attention_mask, size, opts \\ []) do - opts = Keyword.validate!(opts, [:name, :scaling_strategy, max_positions: 2048, base: 10_000]) + opts = Keyword.validate!(opts, [:name, :scaling_strategy, :rotary_dim, max_positions: 2048, base: 10_000]) output = Axon.layer( @@ -1254,15 +1254,24 @@ defmodule Bumblebee.Layers do max_positions, size, base, - scaling_strategy + scaling_strategy, + rotary_dim \\ nil ) do position = Nx.iota({sequence_length}) - range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size) + {num_freqs, denominator} = + if rotary_dim do + {div(rotary_dim, 2), size} + else + {div(size, 2), size} + end + + range = Nx.iota({num_freqs}) |> Nx.multiply(2) |> Nx.divide(denominator) case scaling_strategy do %{type: :linear, factor: factor} -> inv_frequency = inv_frequency(base, range) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) position = Nx.divide(position, factor) positions_cos_sin(position, inv_frequency) @@ -1273,6 +1282,7 @@ defmodule Bumblebee.Layers do |> Nx.pow(size / (size - 2)) inv_frequency = inv_frequency(base, range) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) positions_cos_sin(position, inv_frequency) %{ @@ -1300,6 +1310,7 @@ defmodule Bumblebee.Layers do end inv_frequency = inv_frequency(base, range) |> Nx.divide(factor) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) {cos, sin} = positions_cos_sin(position, inv_frequency) {Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)} @@ -1321,14 +1332,29 @@ defmodule Bumblebee.Layers do original_max_positions ) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) positions_cos_sin(position, inv_frequency) _other -> inv_frequency = inv_frequency(base, range) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) positions_cos_sin(position, inv_frequency) end end + defp maybe_pad_inv_frequency(inv_frequency, _target_size, nil), do: inv_frequency + + defp maybe_pad_inv_frequency(inv_frequency, target_size, _rotary_dim) do + pad_size = target_size - Nx.axis_size(inv_frequency, 0) + + if pad_size > 0 do + padding = Nx.broadcast(Nx.tensor(0.0, type: Nx.type(inv_frequency)), {pad_size}) + Nx.concatenate([inv_frequency, padding]) + else + inv_frequency + end + end + defnp llama3_inv_frequency( inv_frequency, factor, @@ -1381,6 +1407,7 @@ defmodule Bumblebee.Layers do keyword!(opts, [ :size, :scaling_strategy, + :rotary_dim, mode: :inference, max_positions: 2048, base: 10_000 @@ -1400,7 +1427,8 @@ defmodule Bumblebee.Layers do opts[:max_positions], opts[:size], opts[:base], - opts[:scaling_strategy] + opts[:scaling_strategy], + opts[:rotary_dim] ) position_ids = Nx.as_type(position_ids, :s64) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 075c07d9..8289f2a9 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -878,6 +878,7 @@ defmodule Bumblebee.Layers.Transformer do :position_ids, :max_positions, :scaling_strategy, + :rotary_dim, base: 10_000, percentage: 1.0 ]) diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex index 6fc07ec6..4356d13d 100644 --- a/lib/bumblebee/text/gemma4_text.ex +++ b/lib/bumblebee/text/gemma4_text.ex @@ -450,8 +450,10 @@ defmodule Bumblebee.Text.Gemma4Text do [ position_ids: position_ids, max_positions: spec.max_positions, - base: :math.pow(spec.rotary_embedding_base, spec.partial_rotary_factor), - percentage: spec.partial_rotary_factor + base: spec.rotary_embedding_base, + percentage: 1.0, + rotary_dim: + trunc(spec.global_attention_head_size * spec.partial_rotary_factor) ] :sliding_attention -> From 0d707a6083bc0c01608f9392f3b9e83dc645c4f9 Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Tue, 16 Jun 2026 14:40:38 +0300 Subject: [PATCH 6/7] Refactor rotary embedding options and improve parameter mapping in Gemma4Text --- lib/bumblebee/layers.ex | 9 ++- lib/bumblebee/text/gemma4_text.ex | 129 +++++++++++++++--------------- 2 files changed, 73 insertions(+), 65 deletions(-) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index bfca3944..49d333c6 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1237,7 +1237,14 @@ defmodule Bumblebee.Layers do Adds a rotary embedding layer to the network. """ def rotary_embedding(query, key, position_ids, attention_mask, size, opts \\ []) do - opts = Keyword.validate!(opts, [:name, :scaling_strategy, :rotary_dim, max_positions: 2048, base: 10_000]) + opts = + Keyword.validate!(opts, [ + :name, + :scaling_strategy, + :rotary_dim, + max_positions: 2048, + base: 10_000 + ]) output = Axon.layer( diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex index 4356d13d..0b593884 100644 --- a/lib/bumblebee/text/gemma4_text.ex +++ b/lib/bumblebee/text/gemma4_text.ex @@ -452,8 +452,7 @@ defmodule Bumblebee.Text.Gemma4Text do max_positions: spec.max_positions, base: spec.rotary_embedding_base, percentage: 1.0, - rotary_dim: - trunc(spec.global_attention_head_size * spec.partial_rotary_factor) + rotary_dim: trunc(spec.global_attention_head_size * spec.partial_rotary_factor) ] :sliding_attention -> @@ -588,14 +587,13 @@ defmodule Bumblebee.Text.Gemma4Text do shortcut_ple = hidden_state - # Gate: project hidden_state DOWN to PLE dimension + # Gate: project hidden_state DOWN to PLE dimension, then activation gated = Axon.dense(hidden_state, spec.hidden_size_per_layer_input, name: join(name, "per_layer_input_gate"), use_bias: false ) - # Activation (gelu_approx_tanh, same as FFN) gated = Layers.activation(gated, spec.activation) # Element-wise multiply with PLE signal @@ -622,18 +620,19 @@ defmodule Bumblebee.Text.Gemma4Text do end # 4. Layer scalar: multiply output by per-layer learned scalar - layer_scalar = - Axon.param("layer_scalar", fn _ -> {1} end, initializer: :ones) - hidden_state = Axon.layer( fn hidden_state, scalar, _opts -> - Nx.multiply(hidden_state, scalar) + Nx.multiply(hidden_state, Nx.reshape(scalar, {})) end, - [hidden_state, layer_scalar], - name: join(name, "layer_scalar") + [ + hidden_state, + Axon.param("layer_scalar", fn _ -> {1} end, initializer: Axon.Initializers.ones()) + ], + name: join(name, "layer_scalar_op") ) + # Handle cross-attention (required by block interface but not used by Gemma 4) {_hidden_state, cross_attention_info} = steps.cross_attention_maybe.(hidden_state, fn _ -> raise "cross attention not supported" @@ -756,59 +755,61 @@ defmodule Bumblebee.Text.Gemma4Text do end defimpl Bumblebee.HuggingFace.Transformers.Model do - def params_mapping(spec) do - %{ - "embedder.token_embedding" => "model.language_model.embed_tokens", - # PLE (Per-Layer Embeddings) global weights - "embedder.token_embedding_per_layer" => "model.language_model.embed_tokens_per_layer", - "per_layer_model_projection" => "model.language_model.per_layer_model_projection", - "per_layer_projection_norm" => "model.language_model.per_layer_projection_norm", - "decoder.blocks.{n}.per_layer_input_gate" => - "model.language_model.layers.{n}.per_layer_input_gate", - "decoder.blocks.{n}.per_layer_projection" => - "model.language_model.layers.{n}.per_layer_projection", - "decoder.blocks.{n}.post_per_layer_input_norm" => - "model.language_model.layers.{n}.post_per_layer_input_norm", - # Per-layer scalar - "decoder.blocks.{n}.layer_scalar" => - "model.language_model.layers.{n}", - # Attention projections - "decoder.blocks.{n}.self_attention.query" => - "model.language_model.layers.{n}.self_attn.q_proj", - "decoder.blocks.{n}.self_attention.key" => - "model.language_model.layers.{n}.self_attn.k_proj", - "decoder.blocks.{n}.self_attention.value" => - "model.language_model.layers.{n}.self_attn.v_proj", - "decoder.blocks.{n}.self_attention.output" => - "model.language_model.layers.{n}.self_attn.o_proj", - # QK-norm - "decoder.blocks.{n}.self_attention.query_norm" => - "model.language_model.layers.{n}.self_attn.q_norm", - "decoder.blocks.{n}.self_attention.key_norm" => - "model.language_model.layers.{n}.self_attn.k_norm", - # Layer norms - "decoder.blocks.{n}.self_attention_norm" => - "model.language_model.layers.{n}.input_layernorm", - "decoder.blocks.{n}.post_attention_norm" => - "model.language_model.layers.{n}.post_attention_layernorm", - # FFN layer norms - "decoder.blocks.{n}.pre_ffn_norm" => - "model.language_model.layers.{n}.pre_feedforward_layernorm", - "decoder.blocks.{n}.post_ffn_norm" => - "model.language_model.layers.{n}.post_feedforward_layernorm", - # FFN projections - "decoder.blocks.{n}.ffn.gate" => "model.language_model.layers.{n}.mlp.gate_proj", - "decoder.blocks.{n}.ffn.intermediate" => "model.language_model.layers.{n}.mlp.up_proj", - "decoder.blocks.{n}.ffn.output" => "model.language_model.layers.{n}.mlp.down_proj", - # Output - "output_norm" => "model.language_model.norm", - "language_modeling_head.output" => - if(spec.tie_word_embeddings, - do: "model.language_model.embed_tokens", - else: "lm_head" - ), - "sequence_classification_head.output" => "score" - } - end + def params_mapping(spec) do + %{ + "embedder.token_embedding" => "model.language_model.embed_tokens", + # PLE global weights + "embedder.token_embedding_per_layer" => "model.language_model.embed_tokens_per_layer", + "per_layer_model_projection" => "model.language_model.per_layer_model_projection", + "per_layer_projection_norm" => "model.language_model.per_layer_projection_norm", + # PLE per-layer weights + "decoder.blocks.{n}.per_layer_input_gate" => + "model.language_model.layers.{n}.per_layer_input_gate", + "decoder.blocks.{n}.per_layer_projection" => + "model.language_model.layers.{n}.per_layer_projection", + "decoder.blocks.{n}.post_per_layer_input_norm" => + "model.language_model.layers.{n}.post_per_layer_input_norm", + # Layer scalar + "decoder.blocks.{n}.layer_scalar_op" => "model.language_model.layers.{n}", + "decoder.blocks.{n}.layer_scalar_op.layer_scalar" => + "model.language_model.layers.{n}.layer_scalar", + # Attention projections + "decoder.blocks.{n}.self_attention.query" => + "model.language_model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => + "model.language_model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => + "model.language_model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => + "model.language_model.layers.{n}.self_attn.o_proj", + # QK-norm + "decoder.blocks.{n}.self_attention.query_norm" => + "model.language_model.layers.{n}.self_attn.q_norm", + "decoder.blocks.{n}.self_attention.key_norm" => + "model.language_model.layers.{n}.self_attn.k_norm", + # Layer norms + "decoder.blocks.{n}.self_attention_norm" => + "model.language_model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.post_attention_norm" => + "model.language_model.layers.{n}.post_attention_layernorm", + # FFN layer norms + "decoder.blocks.{n}.pre_ffn_norm" => + "model.language_model.layers.{n}.pre_feedforward_layernorm", + "decoder.blocks.{n}.post_ffn_norm" => + "model.language_model.layers.{n}.post_feedforward_layernorm", + # FFN projections + "decoder.blocks.{n}.ffn.gate" => "model.language_model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.language_model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.language_model.layers.{n}.mlp.down_proj", + # Output + "output_norm" => "model.language_model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, + do: "model.language_model.embed_tokens", + else: "lm_head" + ), + "sequence_classification_head.output" => "score" + } + end end end From 425111c1a718a2357eb778e5159ad9612c039757 Mon Sep 17 00:00:00 2001 From: edenamram <74473448+edenamram@users.noreply.github.com> Date: Thu, 18 Jun 2026 10:05:11 +0300 Subject: [PATCH 7/7] Refactor Gemma4Text to remove shift parameter from normalization layers and enhance attention block implementation --- lib/bumblebee/text/gemma4_text.ex | 399 ++++++++++++++++++++---------- 1 file changed, 275 insertions(+), 124 deletions(-) diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex index 0b593884..bb1d4394 100644 --- a/lib/bumblebee/text/gemma4_text.ex +++ b/lib/bumblebee/text/gemma4_text.ex @@ -328,7 +328,6 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state = Layers.rms_norm(decoder_outputs.hidden_state, name: "output_norm", - shift: 1.0, epsilon: spec.layer_norm_epsilon, upcast: :all ) @@ -380,7 +379,6 @@ defmodule Bumblebee.Text.Gemma4Text do end) |> Layers.rms_norm( name: "per_layer_projection_norm", - shift: 1.0, epsilon: spec.layer_norm_epsilon ) @@ -425,8 +423,8 @@ defmodule Bumblebee.Text.Gemma4Text do name = opts[:name] per_layer_inputs = opts[:per_layer_inputs] - query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) - key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + query_norm = &Layers.rms_norm(&1, epsilon: spec.layer_norm_epsilon, name: &2) + key_norm = &Layers.rms_norm(&1, epsilon: spec.layer_norm_epsilon, name: &2) value_norm = fn value, _name -> Axon.nx(value, fn x -> @@ -436,111 +434,143 @@ defmodule Bumblebee.Text.Gemma4Text do end layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + first_kv_shared = spec.num_blocks - spec.num_kv_shared_layers - attention_window_size = fn idx -> - case Enum.at(layer_types, idx, :sliding_attention) do - :full_attention -> nil - :sliding_attention -> {spec.attention_window_size, spec.attention_window_size} - end - end - - rotary_embedding = fn idx -> - case Enum.at(layer_types, idx, :sliding_attention) do - :full_attention -> - [ - position_ids: position_ids, - max_positions: spec.max_positions, - base: spec.rotary_embedding_base, - percentage: 1.0, - rotary_dim: trunc(spec.global_attention_head_size * spec.partial_rotary_factor) - ] - - :sliding_attention -> - [ - position_ids: position_ids, - max_positions: spec.max_positions, - base: spec.rotary_embedding_base_local, - percentage: 1.0 - ] - end - end + # Last occurrence of each layer type before first_kv_shared — these become "store" layers + store_layer_indices = + Enum.reduce(0..(first_kv_shared - 1), %{}, fn idx, acc -> + Map.put(acc, Enum.at(layer_types, idx), idx) + end) attention_scale = 1.0 - non_double_wide_count = spec.num_blocks - spec.num_kv_shared_layers - Layers.Transformer.blocks(hidden_state, - attention_mask: attention_mask, - attention_head_mask: attention_head_mask, + {attention_mask, cache} = Layers.Decoder.cached_attention_mask(attention_mask, cache) + offset = Layers.Decoder.get_cache_offset(cache) + + initial_state = %{ + hidden_state: hidden_state, + hidden_states: Axon.container({hidden_state}), + attentions: Axon.container({}), cache: cache, - num_blocks: spec.num_blocks, - num_attention_heads: spec.num_attention_heads, - num_key_value_heads: spec.num_key_value_heads, - hidden_size: spec.hidden_size, - attention_head_size: fn idx -> - case Enum.at(layer_types, idx, :sliding_attention) do - :full_attention -> spec.global_attention_head_size - :sliding_attention -> spec.attention_head_size - end - end, - attention_scale: attention_scale, - kernel_initializer: kernel_initializer(spec), - layer_norm: - &Layers.rms_norm(&1, - shift: 1.0, - name: &2, - epsilon: spec.layer_norm_epsilon, - upcast: :all - ), - ffn: fn hidden_state, ffn_name -> - idx = - ffn_name - |> String.split(".") - |> Enum.at(2) - |> String.to_integer() - - intermediate_size = - if spec.use_double_wide_mlp and idx >= non_double_wide_count do - spec.intermediate_size * 2 - else - spec.intermediate_size + shared_kv: %{} + } + + outputs = + Enum.reduce(0..(spec.num_blocks - 1), initial_state, fn idx, state -> + layer_type = Enum.at(layer_types, idx) + is_shared = spec.num_kv_shared_layers > 0 and idx >= first_kv_shared + is_store = not is_shared and Map.get(store_layer_indices, layer_type) == idx + + block_name = join(join(name, "blocks"), idx) + block_cache = Layers.Decoder.get_block_cache(state.cache, idx) + block_attention_head_mask = Axon.nx(attention_head_mask, & &1[idx]) + + head_size = + case layer_type do + :full_attention -> spec.global_attention_head_size + :sliding_attention -> spec.attention_head_size end - gated_ffn(hidden_state, intermediate_size, spec.hidden_size, - name: ffn_name, - activation: spec.activation - ) - end, - block_type: &gemma4_block_impl(&1, &2, &3, spec, per_layer_inputs), - causal: true, - rotary_embedding: rotary_embedding, - attention_window_size: attention_window_size, - query_norm: query_norm, - key_norm: key_norm, - value_norm: value_norm, - query_use_bias: spec.use_attention_bias, - key_use_bias: spec.use_attention_bias, - value_use_bias: spec.use_attention_bias, - output_use_bias: spec.use_attention_bias, - name: join(name, "blocks") - ) + num_kv_heads = + case layer_type do + :full_attention -> spec.num_global_key_value_heads || spec.num_key_value_heads + :sliding_attention -> spec.num_key_value_heads + end + + window_size = + case layer_type do + :full_attention -> nil + :sliding_attention -> {spec.attention_window_size, spec.attention_window_size} + end + + rotary_opts = + case layer_type do + :full_attention -> + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + rotary_dim: trunc(spec.global_attention_head_size * spec.partial_rotary_factor) + ] + + :sliding_attention -> + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base_local + ] + end + + precomputed_kv = if is_shared, do: Map.get(state.shared_kv, layer_type), else: nil + + {block_hidden_state, block_cache, pre_rope_kv} = + gemma4_block( + state.hidden_state, + block_cache, + offset, + idx, + spec, + per_layer_inputs, + precomputed_kv, + %{ + attention_mask: attention_mask, + attention_head_mask: block_attention_head_mask, + rotary_opts: rotary_opts, + window_size: window_size, + head_size: head_size, + num_kv_heads: num_kv_heads, + query_norm: query_norm, + key_norm: key_norm, + value_norm: value_norm, + attention_scale: attention_scale, + first_kv_shared: first_kv_shared, + name: block_name + } + ) + + updated_shared_kv = + if is_store, + do: Map.put(state.shared_kv, layer_type, pre_rope_kv), + else: state.shared_kv + + new_cache = Layers.Decoder.put_block_cache(state.cache, idx, block_cache) + + %{ + hidden_state: block_hidden_state, + hidden_states: Layers.append(state.hidden_states, block_hidden_state), + attentions: Layers.append(state.attentions, Layers.none()), + cache: new_cache, + shared_kv: updated_shared_kv + } + end) + + update_in(outputs.cache, &Layers.Decoder.update_cache_offset(&1, outputs.hidden_state)) end - # Custom block implementation for Gemma 4's normalization structure: - # - Post-attention norm BEFORE residual add - # - Pre/post FFN norms - # - PLE applied AFTER attention+MLP as a third residual block - defp gemma4_block_impl(hidden_state, steps, name, spec, per_layer_inputs) do - # 1. Self-attention with pre/post norms + # Builds one Gemma4 decoder block with: + # - Pre-attention norm, self-attention, post-attention norm, residual + # - Pre-FFN norm, FFN, post-FFN norm, residual + # - Optional PLE block + # - Layer scalar + # Returns {hidden_state, block_cache, pre_rope_kv} where pre_rope_kv is nil for shared layers + defp gemma4_block(hidden_state, block_cache, offset, idx, spec, per_layer_inputs, precomputed_kv, opts) do + name = opts.name + + # 1. Self-attention shortcut = hidden_state - {hidden_state, attention_info} = - hidden_state - |> steps.self_attention_norm.() - |> steps.self_attention.() + hidden_state = + Layers.rms_norm(hidden_state, + name: join(name, "self_attention_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + {hidden_state, block_cache, pre_rope_kv} = + gemma4_attention(hidden_state, block_cache, offset, spec, precomputed_kv, opts) hidden_state = Layers.rms_norm(hidden_state, - shift: 1.0, name: join(name, "post_attention_norm"), epsilon: spec.layer_norm_epsilon, upcast: :all @@ -548,22 +578,31 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state = Axon.add(shortcut, hidden_state) - # 2. FFN with pre/post norms + # 2. FFN shortcut = hidden_state hidden_state = Layers.rms_norm(hidden_state, - shift: 1.0, name: join(name, "pre_ffn_norm"), epsilon: spec.layer_norm_epsilon, upcast: :all ) - hidden_state = steps.ffn.(hidden_state) + intermediate_size = + if spec.use_double_wide_mlp and idx >= opts.first_kv_shared do + spec.intermediate_size * 2 + else + spec.intermediate_size + end + + hidden_state = + gated_ffn(hidden_state, intermediate_size, spec.hidden_size, + name: join(name, "ffn"), + activation: spec.activation + ) hidden_state = Layers.rms_norm(hidden_state, - shift: 1.0, name: join(name, "post_ffn_norm"), epsilon: spec.layer_norm_epsilon, upcast: :all @@ -571,23 +610,12 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state = Axon.add(shortcut, hidden_state) - # 3. PLE: gate hidden_state down to PLE dim, multiply with PLE signal, project back up + # 3. PLE hidden_state = if per_layer_inputs do - idx = - name - |> String.split(".") - |> Enum.at(2) - |> String.to_integer() - - ple_slice = - Axon.nx(per_layer_inputs, fn x -> - x[[.., .., idx, ..]] - end) - + ple_slice = Axon.nx(per_layer_inputs, fn x -> x[[.., .., idx, ..]] end) shortcut_ple = hidden_state - # Gate: project hidden_state DOWN to PLE dimension, then activation gated = Axon.dense(hidden_state, spec.hidden_size_per_layer_input, name: join(name, "per_layer_input_gate"), @@ -595,21 +623,16 @@ defmodule Bumblebee.Text.Gemma4Text do ) gated = Layers.activation(gated, spec.activation) - - # Element-wise multiply with PLE signal gated = Axon.multiply(gated, ple_slice) - # Project back UP to hidden dimension gated = Axon.dense(gated, spec.hidden_size, name: join(name, "per_layer_projection"), use_bias: false ) - # Normalize gated = Layers.rms_norm(gated, - shift: 1.0, name: join(name, "post_per_layer_input_norm"), epsilon: spec.layer_norm_epsilon ) @@ -619,7 +642,7 @@ defmodule Bumblebee.Text.Gemma4Text do hidden_state end - # 4. Layer scalar: multiply output by per-layer learned scalar + # 4. Layer scalar hidden_state = Axon.layer( fn hidden_state, scalar, _opts -> @@ -632,13 +655,141 @@ defmodule Bumblebee.Text.Gemma4Text do name: join(name, "layer_scalar_op") ) - # Handle cross-attention (required by block interface but not used by Gemma 4) - {_hidden_state, cross_attention_info} = - steps.cross_attention_maybe.(hidden_state, fn _ -> - raise "cross attention not supported" - end) + {hidden_state, block_cache, pre_rope_kv} + end + + # Builds self-attention for one Gemma4 block. + # Non-shared layers: compute Q/K/V, apply RoPE to Q+K, GQA-expand, return post-RoPE expanded K/V. + # Shared layers: compute Q only, apply RoPE to Q only, reuse stored post-RoPE K/V from store layer. + # Returns {attention_output, block_cache, storable_kv}. + defp gemma4_attention(hidden_state, block_cache, offset, spec, precomputed_kv, opts) do + name = join(opts.name, "self_attention") + + head_size = opts.head_size + num_kv_heads = opts.num_kv_heads + num_q_heads = spec.num_attention_heads + inner_size = num_q_heads * head_size + inner_kv_size = num_kv_heads * head_size + + rotary_opts = opts.rotary_opts + position_ids = rotary_opts[:position_ids] + + rotary_call_opts = + rotary_opts + |> Keyword.delete(:position_ids) + |> Keyword.put(:name, join(name, "rotary_embedding")) + + # Q projection + split heads + Q-norm (always computed) + query = + hidden_state + |> Axon.dense(inner_size, name: join(name, "query"), use_bias: spec.use_attention_bias) + |> Layers.split_heads(num_q_heads) + |> opts.query_norm.(join(name, "query_norm")) + + {query, key, value, storable_kv} = + if precomputed_kv do + # Shared layer: K/V are already post-RoPE, post-GQA from store layer. + # Apply RoPE to Q only by passing stored key through and discarding the re-rotated key. + {stored_key, stored_value} = precomputed_kv + + {rotated_query, _discarded} = + Layers.rotary_embedding( + query, + stored_key, + position_ids, + opts.attention_mask, + head_size, + rotary_call_opts + ) + + {rotated_query, stored_key, stored_value, nil} + else + # Non-shared layer: compute K/V projections + norms + key = + hidden_state + |> Axon.dense(inner_kv_size, + name: join(name, "key"), + use_bias: spec.use_attention_bias + ) + |> Layers.split_heads(num_kv_heads) + |> opts.key_norm.(join(name, "key_norm")) + + value = + hidden_state + |> Axon.dense(inner_kv_size, + name: join(name, "value"), + use_bias: spec.use_attention_bias + ) + |> Layers.split_heads(num_kv_heads) + |> opts.value_norm.(join(name, "value_norm")) + + # Apply RoPE to both Q and K + {rotated_query, rotated_key} = + Layers.rotary_embedding( + query, + key, + position_ids, + opts.attention_mask, + head_size, + rotary_call_opts + ) + + # GQA: expand K/V heads to match Q heads + num_kv_groups = div(num_q_heads, num_kv_heads) + + expanded_key = + if num_kv_groups > 1, + do: Layers.repeat_interleave(rotated_key, num_kv_groups, axis: 2), + else: rotated_key + + expanded_value = + if num_kv_groups > 1, + do: Layers.repeat_interleave(value, num_kv_groups, axis: 2), + else: value + + # Storable: post-RoPE, post-GQA (matches Python's shared_kv_states) + {rotated_query, expanded_key, expanded_value, {expanded_key, expanded_value}} + end + + # KV cache update + {self_attention_cache, cross_attention_cache} = + Layers.Decoder.get_attention_caches(block_cache) + + {key, value, self_attention_cache} = + Layers.Decoder.cached_attention_key_values(key, value, self_attention_cache, offset) + + # Scaled dot-product attention + {attention_output, _weights} = + Layers.attention( + query, + key, + value, + opts.attention_mask, + opts.attention_head_mask, + Layers.none(), + offset, + scale: opts.attention_scale, + causal: true, + window_size: opts.window_size + ) + + # Output projection + attention_output = + attention_output + |> Layers.flatten_trailing() + |> Axon.dense(spec.hidden_size, + name: join(name, "output"), + use_bias: spec.use_attention_bias + ) + + block_cache = + Layers.Decoder.put_attention_caches( + block_cache, + self_attention_cache, + cross_attention_cache + ) - {hidden_state, attention_info, cross_attention_info} + {attention_output, block_cache, storable_kv} end defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do