@@ -12,7 +12,8 @@ defmodule Bumblebee.Text.Generation do
1212 spec :: Bumblebee.ModelSpec . t ( ) ,
1313 batch_size :: pos_integer ( ) ,
1414 max_length :: pos_integer ( ) ,
15- inputs :: map ( )
15+ inputs :: map ( ) ,
16+ opts :: keyword ( )
1617 ) :: cache ( )
1718
1819 @ doc """
@@ -42,9 +43,10 @@ defmodule Bumblebee.Text.Generation do
4243 @ doc """
4344 Initializes an opaque cache input for iterative inference.
4445 """
45- @ spec init_cache ( Bumblebee.ModelSpec . t ( ) , pos_integer ( ) , pos_integer ( ) , map ( ) ) :: cache ( )
46- def init_cache ( % module { } = spec , batch_size , max_length , inputs ) do
47- module . init_cache ( spec , batch_size , max_length , inputs )
46+ @ spec init_cache ( Bumblebee.ModelSpec . t ( ) , pos_integer ( ) , pos_integer ( ) , map ( ) , keyword ( ) ) ::
47+ cache ( )
48+ def init_cache ( % module { } = spec , batch_size , max_length , inputs , opts \\ [ ] ) do
49+ module . init_cache ( spec , batch_size , max_length , inputs , opts )
4850 end
4951
5052 @ doc """
@@ -313,17 +315,13 @@ defmodule Bumblebee.Text.Generation do
313315 |> Map . put ( prefix <> "position_ids" , position_ids )
314316
315317 batch_size = Nx . axis_size ( input_ids , 0 )
316- cache = init_cache ( spec , batch_size , max_length , inputs )
317318
318319 output_policy = model_output_policy ( model )
319-
320- # Cast all float cache tensors to match the model output. This way
321- # we make sure the cache we pass as input has the same types as
322- # the updated cache returned from the model
323- cache =
324- Bumblebee.Utils.Nx . map ( cache , fn tensor ->
325- Axon.MixedPrecision . cast ( output_policy , tensor , :output )
326- end )
320+ # Use the compute precision as the cache type. The key/value tensors are
321+ # produced by projection layers running in compute precision, so this
322+ # matches what the model will actually return for the cache.
323+ cache_type = output_policy . compute || { :f , 32 }
324+ cache = init_cache ( spec , batch_size , max_length , inputs , cache_type: cache_type )
327325
328326 Map . put ( inputs , "cache" , cache )
329327 end
0 commit comments