Skip to content

Commit d0774e8

Browse files
authored
Add :safetensors_reader option to load_model/2 (#456)
1 parent 851632d commit d0774e8

2 files changed

Lines changed: 39 additions & 3 deletions

File tree

lib/bumblebee.ex

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,13 @@ defmodule Bumblebee do
568568
569569
* `:params_filename` - the file with the model parameters to be loaded
570570
571+
* `:safetensors_reader` - a 1-arity function used to read `.safetensors`
572+
parameter files. Receives the file path and must return a map from
573+
tensor name to an `Nx.Tensor` or any term implementing
574+
`Nx.LazyContainer`. Defaults to `&Safetensors.read!(&1, lazy: true)`.
575+
Override to plug in a custom reader, for example one that
576+
memory-maps the file for zero-copy loading
577+
571578
* `:log_params_diff` - whether to log missing, mismatched and unused
572579
parameters. By default diff is logged only if some parameters
573580
cannot be loaded
@@ -617,6 +624,7 @@ defmodule Bumblebee do
617624
:architecture,
618625
:params_variant,
619626
:params_filename,
627+
:safetensors_reader,
620628
:log_params_diff,
621629
:backend,
622630
:type
@@ -659,7 +667,7 @@ defmodule Bumblebee do
659667
filename
660668
|> String.replace_suffix(".index.json", "")
661669
|> Path.extname()
662-
|> params_file_loader_fun()
670+
|> params_file_loader_fun(opts)
663671

664672
with {:ok, paths} <- download_params_files(repository, repo_files, filename, sharded?) do
665673
opts =
@@ -768,8 +776,11 @@ defmodule Bumblebee do
768776
end
769777
end
770778

771-
defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!(&1, lazy: true)
772-
defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorchLoader.load!/1
779+
defp params_file_loader_fun(".safetensors", opts) do
780+
opts[:safetensors_reader] || (&Safetensors.read!(&1, lazy: true))
781+
end
782+
783+
defp params_file_loader_fun(_, _opts), do: &Bumblebee.Conversion.PyTorchLoader.load!/1
773784

774785
@doc """
775786
Featurizes `input` with the given featurizer.

test/bumblebee_test.exs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,30 @@ defmodule BumblebeeTest do
8484
assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16}
8585
assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16}
8686
end
87+
88+
test "uses :safetensors_reader to read .safetensors files" do
89+
test_pid = self()
90+
91+
reader = fn path ->
92+
send(test_pid, {:reader_called, path})
93+
Safetensors.read!(path, lazy: true)
94+
end
95+
96+
assert {:ok, %{params: params}} =
97+
Bumblebee.load_model(
98+
{:hf, "bumblebee-testing/tiny-random-GPT2Model-safetensors-only"},
99+
safetensors_reader: reader
100+
)
101+
102+
assert_received {:reader_called, path}
103+
assert File.exists?(path)
104+
105+
assert {:ok, %{params: default_params}} =
106+
Bumblebee.load_model(
107+
{:hf, "bumblebee-testing/tiny-random-GPT2Model-safetensors-only"}
108+
)
109+
110+
assert Enum.sort(Map.keys(params.data)) == Enum.sort(Map.keys(default_params.data))
111+
end
87112
end
88113
end

0 commit comments

Comments
 (0)