@@ -192,7 +192,9 @@ def prefill(
192192 )
193193 return (prefix , result_tokens )
194194
195- @functools .partial (jax .jit , static_argnums = (0 ,))
195+ @functools .partial (
196+ jax .jit , static_argnums = (0 ,), static_argnames = ("num_samples" ,)
197+ )
196198 def prefill_multisampling (
197199 self ,
198200 * ,
@@ -216,26 +218,30 @@ def prefill_multisampling(
216218 # Generate dummy prefill cache content
217219 prefill_cache = padded_tokens [None , :] * params
218220
219- # Create a dummy first generated token.
220- first_generated_token = (prefill_cache .sum (axis = - 1 ).astype (jnp .int32 ))[
221- :, jnp .newaxis
222- ]
221+ # Create dummy first generated tokens.
222+ first_generated_tokens = []
223+ for _ in range (num_samples ):
224+ first_generated_token = (prefill_cache .sum (axis = - 1 ).astype (jnp .int32 ))[
225+ :, jnp .newaxis
226+ ]
227+ first_generated_tokens .append (first_generated_token )
228+ first_generated_tokens = jnp .concatenate (first_generated_tokens , axis = 0 )
223229
224230 prefix = Prefix (
225231 logits = jax .random .normal (self ._prng_key , (1 , self .vocab_size )),
226232 cache = prefill_cache ,
227233 next_pos = jnp .full ((1 , 1 ), true_length , dtype = jnp .int32 ),
228- num_generated_tokens = jnp .zeros ((1 , 1 ), dtype = jnp .int32 ),
229- first_token = first_generated_token ,
234+ num_generated_tokens = jnp .zeros ((num_samples , 1 ), dtype = jnp .int32 ),
235+ first_token = first_generated_tokens ,
230236 )
231237
232238 speculations = first_generated_token .shape [1 ]
233239 result_tokens = engine_api .ResultTokens (
234240 data = jnp .concatenate (
235241 (
236- first_generated_token ,
237- jnp .ones_like (first_generated_token ),
238- jnp .ones_like (first_generated_token ),
242+ first_generated_tokens ,
243+ jnp .ones_like (first_generated_tokens ),
244+ jnp .ones_like (first_generated_tokens ),
239245 ),
240246 axis = - 1 ,
241247 ),
@@ -244,7 +250,7 @@ def prefill_multisampling(
244250 valid_idx = (speculations , 2 * speculations ),
245251 # And lengths is rank 1.
246252 length_idx = (2 * speculations , 2 * speculations + 1 ),
247- samples_per_slot = self . generate_cache_batch // self . prefill_cache_batch ,
253+ samples_per_slot = num_samples ,
248254 )
249255 return (prefix , result_tokens )
250256
@@ -398,21 +404,21 @@ def bulk_insert(
398404 """Insert a single computed prefill cache into multiple slots in
399405 KV cache.
400406 """
401- prefill_cache = prefix . cache
407+ prefill_cache = decode_state . prefill_cache
402408 generate_cache = decode_state .generate_cache
403409 generate_lengths = decode_state .generate_lengths
404410 generate_tokens = decode_state .generate_tokens
405411 for slot in slots :
406412 prefill_cache = jax .lax .dynamic_update_slice_in_dim (
407- decode_state . prefill_cache , prefill_cache , slot , axis = 0
413+ prefill_cache , prefix . cache , slot , axis = 0
408414 )
409415 generate_cache = jax .lax .dynamic_update_slice_in_dim (
410416 generate_cache ,
411417 jnp .zeros ((1 , self .cache_length )),
412418 slot ,
413419 axis = 0 ,
414420 )
415- samples_per_slot = self . generate_cache_batch // self . prefill_cache_batch
421+ samples_per_slot = 1
416422 generate_lengths = jax .lax .dynamic_update_slice_in_dim (
417423 generate_lengths ,
418424 jnp .ones ((samples_per_slot ), dtype = jnp .int32 ),
0 commit comments