Skip to content

Commit ccd6f4a

Browse files
committed
refactor: pass type options for init_cache
1 parent 080087a commit ccd6f4a

25 files changed

Lines changed: 89 additions & 69 deletions

lib/bumblebee/audio/whisper.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ defmodule Bumblebee.Audio.Whisper do
227227
end
228228

229229
@impl true
230-
def init_cache(spec, batch_size, max_length, inputs) do
230+
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
231231
encoder_sequence_length =
232232
if encoder_hidden_state = inputs["encoder_hidden_state"] do
233233
Nx.axis_size(encoder_hidden_state, 1)
@@ -238,7 +238,8 @@ defmodule Bumblebee.Audio.Whisper do
238238
decoder_num_attention_heads: spec.decoder_num_attention_heads,
239239
encoder_num_attention_heads: spec.encoder_num_attention_heads,
240240
decoder_num_blocks: spec.decoder_num_blocks,
241-
encoder_sequence_length: encoder_sequence_length
241+
encoder_sequence_length: encoder_sequence_length,
242+
attention_cache_type: opts[:cache_type]
242243
)
243244
end
244245

lib/bumblebee/multimodal/blip.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ defmodule Bumblebee.Multimodal.Blip do
178178
%{vision_spec: vision_spec, text_spec: text_spec},
179179
batch_size,
180180
max_length,
181-
inputs
181+
inputs,
182+
opts \\ []
182183
) do
183184
num_patches = div(vision_spec.image_size, vision_spec.patch_size) ** 2
184185
encoder_sequence_length = num_patches + 1
@@ -193,7 +194,7 @@ defmodule Bumblebee.Multimodal.Blip do
193194
}
194195
|> Map.reject(&match?({_, nil}, &1))
195196

196-
text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs)
197+
text_spec.__struct__.init_cache(text_spec, batch_size, max_length, inputs, opts)
197198
end
198199

199200
@impl true

lib/bumblebee/text/bart.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ defmodule Bumblebee.Text.Bart do
417417
end
418418

419419
@impl true
420-
def init_cache(spec, batch_size, max_length, inputs) do
420+
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
421421
encoder_sequence_length =
422422
if encoder_hidden_state = inputs["encoder_hidden_state"] do
423423
Nx.axis_size(encoder_hidden_state, 1)
@@ -428,7 +428,8 @@ defmodule Bumblebee.Text.Bart do
428428
decoder_num_attention_heads: spec.decoder_num_attention_heads,
429429
encoder_num_attention_heads: spec.encoder_num_attention_heads,
430430
decoder_num_blocks: spec.decoder_num_blocks,
431-
encoder_sequence_length: encoder_sequence_length
431+
encoder_sequence_length: encoder_sequence_length,
432+
attention_cache_type: opts[:cache_type]
432433
)
433434
end
434435

lib/bumblebee/text/bert.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ defmodule Bumblebee.Text.Bert do
374374
end
375375

376376
@impl true
377-
def init_cache(spec, batch_size, max_length, inputs) do
377+
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
378378
encoder_sequence_length =
379379
if encoder_hidden_state = inputs["encoder_hidden_state"] do
380380
Nx.axis_size(encoder_hidden_state, 1)
@@ -385,7 +385,8 @@ defmodule Bumblebee.Text.Bert do
385385
decoder_num_attention_heads: spec.num_attention_heads,
386386
encoder_num_attention_heads: spec.num_attention_heads,
387387
decoder_num_blocks: spec.num_blocks,
388-
encoder_sequence_length: encoder_sequence_length
388+
encoder_sequence_length: encoder_sequence_length,
389+
attention_cache_type: opts[:cache_type]
389390
)
390391
end
391392

lib/bumblebee/text/blenderbot.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ defmodule Bumblebee.Text.Blenderbot do
269269
end
270270

271271
@impl true
272-
def init_cache(spec, batch_size, max_length, inputs) do
272+
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
273273
encoder_sequence_length =
274274
if encoder_hidden_state = inputs["encoder_hidden_state"] do
275275
Nx.axis_size(encoder_hidden_state, 1)
@@ -280,7 +280,8 @@ defmodule Bumblebee.Text.Blenderbot do
280280
decoder_num_attention_heads: spec.decoder_num_attention_heads,
281281
encoder_num_attention_heads: spec.encoder_num_attention_heads,
282282
decoder_num_blocks: spec.decoder_num_blocks,
283-
encoder_sequence_length: encoder_sequence_length
283+
encoder_sequence_length: encoder_sequence_length,
284+
attention_cache_type: opts[:cache_type]
284285
)
285286
end
286287

lib/bumblebee/text/blip_text.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ defmodule Bumblebee.Text.BlipText do
182182
end
183183

184184
@impl true
185-
def init_cache(spec, batch_size, max_length, inputs) do
185+
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
186186
encoder_sequence_length =
187187
if encoder_hidden_state = inputs["encoder_hidden_state"] do
188188
Nx.axis_size(encoder_hidden_state, 1)
@@ -193,7 +193,8 @@ defmodule Bumblebee.Text.BlipText do
193193
decoder_num_attention_heads: spec.num_attention_heads,
194194
encoder_num_attention_heads: spec.num_attention_heads,
195195
decoder_num_blocks: spec.num_blocks,
196-
encoder_sequence_length: encoder_sequence_length
196+
encoder_sequence_length: encoder_sequence_length,
197+
attention_cache_type: opts[:cache_type]
197198
)
198199
end
199200

lib/bumblebee/text/gemma.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,14 @@ defmodule Bumblebee.Text.Gemma do
173173
end
174174

175175
@impl true
176-
def init_cache(spec, batch_size, max_length, _inputs) do
176+
def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
177177
Layers.Decoder.init_cache(batch_size, max_length,
178178
hidden_size: spec.hidden_size,
179179
attention_head_size: spec.attention_head_size,
180180
decoder_num_attention_heads: spec.num_attention_heads,
181181
decoder_num_key_value_heads: spec.num_key_value_heads,
182-
decoder_num_blocks: spec.num_blocks
182+
decoder_num_blocks: spec.num_blocks,
183+
attention_cache_type: opts[:cache_type]
183184
)
184185
end
185186

lib/bumblebee/text/gemma3_text.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,14 @@ defmodule Bumblebee.Text.Gemma3Text do
209209
end
210210

211211
@impl true
212-
def init_cache(spec, batch_size, max_length, _inputs) do
212+
def init_cache(spec, batch_size, max_length, _inputs, opts \\ []) do
213213
Layers.Decoder.init_cache(batch_size, max_length,
214214
hidden_size: spec.hidden_size,
215215
attention_head_size: spec.attention_head_size,
216216
decoder_num_attention_heads: spec.num_attention_heads,
217217
decoder_num_key_value_heads: spec.num_key_value_heads,
218-
decoder_num_blocks: spec.num_blocks
218+
decoder_num_blocks: spec.num_blocks,
219+
attention_cache_type: opts[:cache_type]
219220
)
220221
end
221222

lib/bumblebee/text/generation.ex

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ defmodule Bumblebee.Text.Generation do
1212
spec :: Bumblebee.ModelSpec.t(),
1313
batch_size :: pos_integer(),
1414
max_length :: pos_integer(),
15-
inputs :: map()
15+
inputs :: map(),
16+
opts :: keyword()
1617
) :: cache()
1718

1819
@doc """
@@ -42,9 +43,10 @@ defmodule Bumblebee.Text.Generation do
4243
@doc """
4344
Initializes an opaque cache input for iterative inference.
4445
"""
45-
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
46-
def init_cache(%module{} = spec, batch_size, max_length, inputs) do
47-
module.init_cache(spec, batch_size, max_length, inputs)
46+
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map(), keyword()) ::
47+
cache()
48+
def init_cache(%module{} = spec, batch_size, max_length, inputs, opts \\ []) do
49+
module.init_cache(spec, batch_size, max_length, inputs, opts)
4850
end
4951

5052
@doc """
@@ -313,17 +315,13 @@ defmodule Bumblebee.Text.Generation do
313315
|> Map.put(prefix <> "position_ids", position_ids)
314316

315317
batch_size = Nx.axis_size(input_ids, 0)
316-
cache = init_cache(spec, batch_size, max_length, inputs)
317318

318319
output_policy = model_output_policy(model)
319-
320-
# Cast all float cache tensors to match the model output. This way
321-
# we make sure the cache we pass as input has the same types as
322-
# the updated cache returned from the model
323-
cache =
324-
Bumblebee.Utils.Nx.map(cache, fn tensor ->
325-
Axon.MixedPrecision.cast(output_policy, tensor, :output)
326-
end)
320+
# Use the compute precision as the cache type. The key/value tensors are
321+
# produced by projection layers running in compute precision, so this
322+
# matches what the model will actually return for the cache.
323+
cache_type = output_policy.compute || {:f, 32}
324+
cache = init_cache(spec, batch_size, max_length, inputs, cache_type: cache_type)
327325

328326
Map.put(inputs, "cache", cache)
329327
end

lib/bumblebee/text/gpt2.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ defmodule Bumblebee.Text.Gpt2 do
278278
end
279279

280280
@impl true
281-
def init_cache(spec, batch_size, max_length, inputs) do
281+
def init_cache(spec, batch_size, max_length, inputs, opts \\ []) do
282282
encoder_sequence_length =
283283
if encoder_hidden_state = inputs["encoder_hidden_state"] do
284284
Nx.axis_size(encoder_hidden_state, 1)
@@ -289,7 +289,8 @@ defmodule Bumblebee.Text.Gpt2 do
289289
decoder_num_attention_heads: spec.num_attention_heads,
290290
encoder_num_attention_heads: spec.num_attention_heads,
291291
decoder_num_blocks: spec.num_blocks,
292-
encoder_sequence_length: encoder_sequence_length
292+
encoder_sequence_length: encoder_sequence_length,
293+
attention_cache_type: opts[:cache_type]
293294
)
294295
end
295296

0 commit comments

Comments
 (0)