Skip to content
9 changes: 8 additions & 1 deletion lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ defmodule Bumblebee do
"Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
"Gemma3TextForSequenceClassification" =>
{Bumblebee.Text.Gemma3Text, :for_sequence_classification},
"Gemma4ForConditionalGeneration" =>
{Bumblebee.Text.Gemma4Text, :for_causal_language_modeling},
"GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
Expand Down Expand Up @@ -273,6 +275,7 @@ defmodule Bumblebee do
"clip" => :clip,
"gemma" => :gemma,
"gemma3_text" => :gemma,
"gemma4" => :gemma,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
"gpt_bigcode" => :gpt2,
Expand Down Expand Up @@ -777,11 +780,15 @@ defmodule Bumblebee do
end

# Picks the one-argument function used to load a parameters file, based on
# the file extension.
#
# For ".safetensors" files a caller-supplied `:safetensors_reader` option
# takes precedence; when it is absent (or falsy) the file is read through
# `read_safetensors_chunked/1`. Any extension not matched by an earlier
# clause falls back to the PyTorch loader.
defp params_file_loader_fun(".safetensors", opts) do
  opts[:safetensors_reader] || (&read_safetensors_chunked/1)
end

defp params_file_loader_fun(_, _opts), do: &Bumblebee.Conversion.PyTorchLoader.load!/1

# Default safetensors reader. Despite the name, nothing is chunked here:
# the file at `path` is read with `lazy: true` and whatever
# `Safetensors.read!/2` returns for that option is handed back unchanged.
defp read_safetensors_chunked(path), do: Safetensors.read!(path, lazy: true)

@doc """
Featurizes `input` with the given featurizer.

Expand Down
47 changes: 46 additions & 1 deletion lib/bumblebee/conversion/pytorch_params.ex
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do

{value, diff} =
if all_sources_found? do
source_values = Enum.map(source_values, &Nx.to_tensor/1)
source_values = Enum.map(source_values, &lazy_to_tensor/1)
value = builder_fun.(Enum.reverse(source_values))

case verify_param_shape(param_expr, value) do
Expand Down Expand Up @@ -188,6 +188,51 @@ defmodule Bumblebee.Conversion.PyTorchParams do

# Puts `values` in front of the list already stored under `key` in `diff`.
# The key must exist: `Map.update!/3` raises `KeyError` otherwise.
defp prepend(diff, key, values) do
  Map.update!(diff, key, fn existing -> values ++ existing end)
end

# On macOS, pread(2) fails with EINVAL once the requested byte count
# exceeds INT_MAX (~2 GB), so large safetensors tensors are fetched in
# slices of 1 GB (1024 * 1024 * 1024 bytes) instead of in one call.
@pread_chunk 1024 * 1024 * 1024

# Materializes a parameter value as a tensor.
#
# File-backed tensors bigger than `@pread_chunk` bytes are read manually
# through `pread_chunked/3` and built under `Nx.BinaryBackend`: the GPU
# backend (EMLX) cannot allocate tensors this large in a single call, and
# the chunked read sidesteps the macOS pread INT_MAX limit. Every other
# value is passed straight to `Nx.to_tensor/1`.
defp lazy_to_tensor(%Safetensors.FileTensor{byte_size: size} = file_tensor)
     when size > @pread_chunk do
  %{path: path, byte_offset: offset, shape: shape, type: type} = file_tensor

  load = fn ->
    File.open!(path, [:read, :raw], fn io ->
      io
      |> pread_chunked(offset, size)
      |> Safetensors.Shared.build_tensor(shape, type)
    end)
  end

  Nx.with_default_backend(Nx.BinaryBackend, load)
end

defp lazy_to_tensor(value), do: Nx.to_tensor(value)

# Reads `size` bytes starting at `offset` from `file` (a handle accepted by
# `:file.pread/3`) and returns them as a single binary.
#
# No single pread call asks for more than `@pread_chunk` bytes. A read that
# returns fewer bytes than requested is followed by another read for the
# remainder, so the result is always exactly `size` bytes long.
#
# Raises `RuntimeError` when the file ends, or a read fails, before `size`
# bytes have been collected.
defp pread_chunked(file, offset, size) do
  file
  |> do_pread_chunked(offset, size, [])
  |> IO.iodata_to_binary()
end

# Collects chunks in reverse (O(1) prepend) and restores order at the end.
defp do_pread_chunked(_file, _offset, 0, acc), do: Enum.reverse(acc)

defp do_pread_chunked(file, offset, remaining, acc) do
  wanted = min(remaining, @pread_chunk)

  case :file.pread(file, offset, wanted) do
    {:ok, chunk} when is_binary(chunk) and byte_size(chunk) > 0 ->
      # Advance by what was actually returned, not by what was requested.
      read = byte_size(chunk)
      do_pread_chunked(file, offset + read, remaining - read, [chunk | acc])

    {:ok, _empty} ->
      # An empty result would never make progress; treat it as end of file.
      raise "unexpected end of file at offset #{offset}, " <>
              "#{remaining} bytes still expected"

    :eof ->
      raise "unexpected end of file at offset #{offset}, " <>
              "#{remaining} bytes still expected"

    {:error, reason} ->
      raise "failed to read #{wanted} bytes at offset #{offset}: #{inspect(reason)}"
  end
end

defp infer_prefixes(layers, pytorch_state, params_mapping) do
# Note: target refers to the parameters we are initializing, while
# source refers to the state we are loading from
Expand Down
43 changes: 39 additions & 4 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,14 @@ defmodule Bumblebee.Layers do
Adds a rotary embedding layer to the network.
"""
def rotary_embedding(query, key, position_ids, attention_mask, size, opts \\ []) do
opts = Keyword.validate!(opts, [:name, :scaling_strategy, max_positions: 2048, base: 10_000])
opts =
Keyword.validate!(opts, [
:name,
:scaling_strategy,
:rotary_dim,
max_positions: 2048,
base: 10_000
])

output =
Axon.layer(
Expand All @@ -1254,15 +1261,24 @@ defmodule Bumblebee.Layers do
max_positions,
size,
base,
scaling_strategy
scaling_strategy,
rotary_dim \\ nil
) do
position = Nx.iota({sequence_length})

range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size)
{num_freqs, denominator} =
if rotary_dim do
{div(rotary_dim, 2), size}
else
{div(size, 2), size}
end

range = Nx.iota({num_freqs}) |> Nx.multiply(2) |> Nx.divide(denominator)

case scaling_strategy do
%{type: :linear, factor: factor} ->
inv_frequency = inv_frequency(base, range)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
position = Nx.divide(position, factor)
positions_cos_sin(position, inv_frequency)

Expand All @@ -1273,6 +1289,7 @@ defmodule Bumblebee.Layers do
|> Nx.pow(size / (size - 2))

inv_frequency = inv_frequency(base, range)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
positions_cos_sin(position, inv_frequency)

%{
Expand Down Expand Up @@ -1300,6 +1317,7 @@ defmodule Bumblebee.Layers do
end

inv_frequency = inv_frequency(base, range) |> Nx.divide(factor)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
{cos, sin} = positions_cos_sin(position, inv_frequency)
{Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)}

Expand All @@ -1321,14 +1339,29 @@ defmodule Bumblebee.Layers do
original_max_positions
)

inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
positions_cos_sin(position, inv_frequency)

_other ->
inv_frequency = inv_frequency(base, range)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
positions_cos_sin(position, inv_frequency)
end
end

# Extends `inv_frequency` with trailing zeros along its first axis until it
# holds `target_size` entries, but only when a rotary dimension is set.
#
# With `rotary_dim` equal to `nil`, or when the tensor already has at least
# `target_size` entries, the input is returned unchanged. The zeros use the
# same type as `inv_frequency`.
# NOTE(review): presumably the zero frequencies leave the extra positions
# unrotated — confirm against `positions_cos_sin`.
defp maybe_pad_inv_frequency(inv_frequency, target_size, rotary_dim) do
  missing =
    case rotary_dim do
      nil -> 0
      _set -> target_size - Nx.axis_size(inv_frequency, 0)
    end

  if missing <= 0 do
    inv_frequency
  else
    zeros =
      0.0
      |> Nx.tensor(type: Nx.type(inv_frequency))
      |> Nx.broadcast({missing})

    Nx.concatenate([inv_frequency, zeros])
  end
end

defnp llama3_inv_frequency(
inv_frequency,
factor,
Expand Down Expand Up @@ -1381,6 +1414,7 @@ defmodule Bumblebee.Layers do
keyword!(opts, [
:size,
:scaling_strategy,
:rotary_dim,
mode: :inference,
max_positions: 2048,
base: 10_000
Expand All @@ -1400,7 +1434,8 @@ defmodule Bumblebee.Layers do
opts[:max_positions],
opts[:size],
opts[:base],
opts[:scaling_strategy]
opts[:scaling_strategy],
opts[:rotary_dim]
)

position_ids = Nx.as_type(position_ids, :s64)
Expand Down
31 changes: 27 additions & 4 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ defmodule Bumblebee.Layers.Transformer do
:hidden_size,
:ffn,
:kernel_initializer,
:attention_head_size,
:dropout_rate,
:attention_dropout_rate,
:query_use_bias,
Expand All @@ -63,7 +62,8 @@ defmodule Bumblebee.Layers.Transformer do
:block_type,
:attention_scale,
:query_norm,
:key_norm
:key_norm,
:value_norm
]

opts =
Expand All @@ -75,6 +75,7 @@ defmodule Bumblebee.Layers.Transformer do
:num_blocks,
:rotary_embedding,
:attention_window_size,
:attention_head_size,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: nil,
Expand All @@ -97,6 +98,7 @@ defmodule Bumblebee.Layers.Transformer do
cache = opts[:cache]
rotary_embedding = opts[:rotary_embedding]
attention_window_size = opts[:attention_window_size]
attention_head_size = opts[:attention_head_size]

block_opts = Keyword.take(opts, block_opts_keys)

Expand Down Expand Up @@ -142,12 +144,20 @@ defmodule Bumblebee.Layers.Transformer do
size -> size
end

block_attention_head_size =
case attention_head_size do
nil -> nil
fun when is_function(fun, 1) -> fun.(idx)
size -> size
end

{hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
block(
state.hidden_state,
[
attention_mask: attention_mask,
attention_head_mask: block_attention_head_mask,
attention_head_size: block_attention_head_size,
attention_relative_bias: attention_relative_bias,
cross_hidden_state: cross_hidden_state,
cross_attention_mask: cross_attention_mask,
Expand Down Expand Up @@ -354,7 +364,8 @@ defmodule Bumblebee.Layers.Transformer do
attention_scale: nil,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
value_norm: nil
])

name = opts[:name]
Expand Down Expand Up @@ -386,6 +397,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]
value_norm = opts[:value_norm]

ffn_fun =
case ffn do
Expand Down Expand Up @@ -446,6 +458,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding: rotary_embedding,
query_norm: query_norm,
key_norm: key_norm,
value_norm: value_norm,
name: join(name, "self_attention")
)

Expand Down Expand Up @@ -772,7 +785,8 @@ defmodule Bumblebee.Layers.Transformer do
output_use_bias: true,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
value_norm: nil
])

attention_mask = opts[:attention_mask]
Expand All @@ -788,6 +802,7 @@ defmodule Bumblebee.Layers.Transformer do
causal = opts[:causal]
attention_window_size = opts[:attention_window_size]
attention_scale = opts[:attention_scale]
value_norm = opts[:value_norm]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
Expand Down Expand Up @@ -846,6 +861,13 @@ defmodule Bumblebee.Layers.Transformer do
key
end

value =
if value_norm do
value_norm.(value, join(name, "value_norm"))
else
value
end

{query, key} =
case rotary_embedding do
opts when is_list(opts) ->
Expand All @@ -856,6 +878,7 @@ defmodule Bumblebee.Layers.Transformer do
:position_ids,
:max_positions,
:scaling_strategy,
:rotary_dim,
base: 10_000,
percentage: 1.0
])
Expand Down
Loading