@@ -86,6 +86,7 @@ def __init__(
8686 cache_length : int ,
8787 weight : float ,
8888 vocab_size : int = 1024 ,
89+ use_chunked_prefill : bool = False ,
8990 ):
9091 self .prefill_cache_batch = batch_size
9192 self .generate_cache_batch = batch_size
@@ -96,17 +97,24 @@ def __init__(
9697 mesh_utils .create_device_mesh ((1 , 1 , 1 ), jax .devices ()), ("x" , "y" , "z" )
9798 )
9899 self ._prng_key = jax .random .PRNGKey (42 )
100+ self ._use_chunked_prefill = use_chunked_prefill
99101
100102 def load_params (self ) -> Params :
101103 """Loads model weights."""
102104 # An integer, used to multiply inputs.
103105 return jnp .array ([self .weight ], dtype = jnp .float32 )
104106
107+ def load_params_dict (self ) -> Params :
108+ """Loads model weights."""
109+ # An integer, used to multiply inputs.
110+ return {"params" : jnp .array ([self .weight ], dtype = jnp .float32 )}
111+
105112 @functools .partial (
106113 jax .jit ,
107114 static_argnums = (0 ,),
108115 static_argnames = ("request_id" ,),
109116 )
117+ # pylint: disable=unused-argument
110118 def prefill (
111119 self ,
112120 * ,
@@ -115,6 +123,10 @@ def prefill(
115123 padded_tokens : jax .Array ,
116124 true_length : int ,
117125 request_id : Optional [uuid .UUID ] = None ,
126+ previous_chunk = None ,
127+ complete_padded_prompt = None ,
128+ complete_prompt_true_length = None ,
129+ positions = None ,
118130 ) -> Tuple [Prefix , engine_api .ResultTokens ]:
119131 """Computes a kv-cache for a new generate request.
120132
@@ -133,20 +145,33 @@ def prefill(
133145 assert padded_tokens .ndim == 1
134146
135147 # Generate dummy prefill cache content
136- prefill_cache = padded_tokens [None , :] * params
148+ if not self ._use_chunked_prefill :
149+ prefill_cache = padded_tokens [None , :] * params
150+ else :
151+ prefill_cache = padded_tokens [None , :]
137152
138153 # Create a dummy first generated token.
139154 first_generated_token = (prefill_cache .sum (axis = - 1 ).astype (jnp .int32 ))[
140155 :, jnp .newaxis
141156 ]
142157
143- prefix = Prefix (
144- logits = jax .random .normal (self ._prng_key , (1 , self .vocab_size )),
145- cache = prefill_cache ,
146- next_pos = jnp .full ((1 , 1 ), true_length , dtype = jnp .int32 ),
147- num_generated_tokens = jnp .zeros ((1 , 1 ), dtype = jnp .int32 ),
148- first_token = first_generated_token ,
149- )
158+ if not self ._use_chunked_prefill :
159+ prefix = Prefix (
160+ logits = jax .random .normal (self ._prng_key , (1 , self .vocab_size )),
161+ cache = prefill_cache ,
162+ next_pos = jnp .full ((1 , 1 ), true_length , dtype = jnp .int32 ),
163+ num_generated_tokens = jnp .zeros ((1 , 1 ), dtype = jnp .int32 ),
164+ first_token = first_generated_token ,
165+ )
166+ else :
167+ prefix = {
168+ "logits" : jax .random .normal (self ._prng_key , (1 , self .vocab_size )),
169+ "cache" : prefill_cache ,
170+ "next_pos" : jnp .full ((1 , 1 ), true_length , dtype = jnp .int32 ),
171+ "generated_tokens" : jnp .zeros ((1 , 1 ), dtype = jnp .int32 ),
172+ "tokens" : first_generated_token ,
173+ "first_token" : first_generated_token ,
174+ }
150175
151176 speculations = first_generated_token .shape [1 ]
152177 result_tokens = engine_api .ResultTokens (
@@ -319,15 +344,19 @@ def generate(
319344 )
320345 def insert (
321346 self ,
322- prefix : Prefix ,
347+ prefix : Any ,
323348 decode_state : DecodeState ,
324349 slot : int ,
325350 request_id : Optional [uuid .UUID ] = None ,
326351 ) -> DecodeState :
327352 """Adds `prefix` into `decode_state` at `slot`."""
328- prefill_cache = prefix .cache
353+ if not self ._use_chunked_prefill :
354+ prefill_cache = prefix .cache
355+ else :
356+ prefill_cache = prefix ["cache" ]
357+
329358 prefill_cache = jax .lax .dynamic_update_slice_in_dim (
330- decode_state .prefill_cache , prefill_cache , slot , axis = 0
359+ decode_state .prefill_cache , prefill_cache * 1.0 , slot , axis = 0
331360 )
332361 generate_cache = jax .lax .dynamic_update_slice_in_dim (
333362 decode_state .generate_cache ,
@@ -342,9 +371,13 @@ def insert(
342371 slot * samples_per_slot ,
343372 axis = 0 ,
344373 )
374+ if not self ._use_chunked_prefill :
375+ first_token = prefix .first_token
376+ else :
377+ first_token = prefix ["first_token" ]
345378 generate_tokens = jax .lax .dynamic_update_slice_in_dim (
346379 decode_state .generate_tokens ,
347- prefix . first_token ,
380+ first_token ,
348381 slot * samples_per_slot ,
349382 axis = 0 ,
350383 )
@@ -455,3 +488,18 @@ def mesh(self) -> jax.sharding.Mesh:
455488 def colocated_cpus (self ) -> None :
456489 """CPU devices colocated with the engine's accelerators."""
457490 raise NotImplementedError
491+
492+ @property
493+ def use_chunked_prefill (self ) -> bool :
494+ """Maximum prefill length."""
495+ return self ._use_chunked_prefill
496+
497+ @property
498+ def chunk_size (self ) -> bool :
499+ """Maximum prefill length."""
500+ return 2
501+
502+ @property
503+ def prefill_chunk_size (self ) -> int :
504+ """Maximum prefill length."""
505+ return 64
0 commit comments