@@ -568,6 +568,13 @@ defmodule Bumblebee do
568568
569569 * `:params_filename` - the file with the model parameters to be loaded
570570
571+ * `:safetensors_reader` - a 1-arity function used to read `.safetensors`
572+ parameter files. Receives the file path and must return a map from
573+ tensor name to an `Nx.Tensor` or any term implementing
574+ `Nx.LazyContainer`. Defaults to `&Safetensors.read!(&1, lazy: true)`.
575+ Override to plug in a custom reader, for example one that
576+ memory-maps the file for zero-copy loading
577+
571578 * `:log_params_diff` - whether to log missing, mismatched and unused
572579 parameters. By default diff is logged only if some parameters
573580 cannot be loaded
@@ -617,6 +624,7 @@ defmodule Bumblebee do
617624 :architecture ,
618625 :params_variant ,
619626 :params_filename ,
627+ :safetensors_reader ,
620628 :log_params_diff ,
621629 :backend ,
622630 :type
@@ -659,7 +667,7 @@ defmodule Bumblebee do
659667 filename
660668 |> String . replace_suffix ( ".index.json" , "" )
661669 |> Path . extname ( )
662- |> params_file_loader_fun ( )
670+ |> params_file_loader_fun ( opts )
663671
664672 with { :ok , paths } <- download_params_files ( repository , repo_files , filename , sharded? ) do
665673 opts =
@@ -768,8 +776,11 @@ defmodule Bumblebee do
768776 end
769777 end
770778
771- defp params_file_loader_fun ( ".safetensors" ) , do: & Safetensors . read! ( & 1 , lazy: true )
772- defp params_file_loader_fun ( _ ) , do: & Bumblebee.Conversion.PyTorchLoader . load! / 1
779+ defp params_file_loader_fun ( ".safetensors" , opts ) do
780+ opts [ :safetensors_reader ] || ( & Safetensors . read! ( & 1 , lazy: true ) )
781+ end
782+
783+ defp params_file_loader_fun ( _ , _opts ) , do: & Bumblebee.Conversion.PyTorchLoader . load! / 1
773784
774785 @ doc """
775786 Featurizes `input` with the given featurizer.
0 commit comments