@@ -72,6 +72,8 @@ class EngineConfig:
7272 top_p : float = 0.8
7373 top_k : int = 1
7474 enable_graph : bool = False
75+ enable_chunk_prefill_graph : bool = False
76+ chunk_size : int = 0
7577 attn_backend : str = "default"
7678 skip_load : bool = False
7779
@@ -91,6 +93,7 @@ def __init__(self, config: EngineConfig):
9193 device = self .device ,
9294 distributed_config = DistConfig (config .tensor_parallel_size ),
9395 enable_graph_compiling = config .enable_graph ,
96+ enable_chunk_prefill_graph = config .enable_chunk_prefill_graph ,
9497 attention_backend = config .attn_backend ,
9598 )
9699
@@ -167,6 +170,8 @@ def _init_device(self):
167170
168171 def add_request (self , request : InferenceRequest ):
169172 """Add a request to the scheduler."""
173+ if self .cache_type == "paged" and self .config .chunk_size > 0 :
174+ request .chunk_size = self .config .chunk_size
170175 self .scheduler .add_request (request )
171176
172177 def step (self ) -> tuple [list [InferenceRequest ], list [tuple ]]:
@@ -210,14 +215,39 @@ def _update_requests(
210215 sampled_tokens : List [int ],
211216 ) -> List [tuple ]:
212217 """Update request status after inference step."""
213- if is_prefill :
218+ # Detect a chunked-prefill mid-step: single request, prefill phase,
219+ # and this chunk does not yet cover the whole prompt. In that case
220+ # we must NOT consume a sampled token, NOT commit prefill blocks,
221+ # and re-enqueue the request to keep chunking.
222+ chunk_mid_step = (
223+ is_prefill
224+ and len (requests ) == 1
225+ and requests [0 ].is_chunking ()
226+ and not requests [0 ].chunk_is_last ()
227+ )
228+
229+ if is_prefill and not chunk_mid_step :
214230 match self .cache_type :
215231 case "paged" :
216232 self .scheduler .cache_manager .reset_req_blocks ()
217233 case "static" :
218234 self .scheduler .update_cache ()
219235 case _:
220236 raise ValueError (f"Unsupported cache_type: { self .cache_type } " )
237+
238+ if chunk_mid_step :
239+ req = requests [0 ]
240+ req .chunk_prefill_offset += req .chunk_size
241+ # If this request was aborted while chunking, drop it.
242+ if req .is_aborted ():
243+ logger .info (
244+ f"Request { req .request_id } aborted by client during chunked-prefill"
245+ )
246+ return []
247+ # Re-enqueue to keep producing chunks; no token sampled yet.
248+ self .scheduler .requeue_chunking (req )
249+ return []
250+
221251 pending = []
222252 for req , token_id in zip (requests , sampled_tokens ):
223253 if req .is_aborted ():
@@ -227,6 +257,10 @@ def _update_requests(
227257 continue
228258
229259 if req .is_prefill :
260+ # Clean up chunked-prefill state on the final chunk so the
261+ # next forward pass on this request takes the decode path.
262+ req .chunk_prefill_offset = 0
263+ req .chunk_size = 0
230264 req .is_prefill = False
231265
232266 req .generated_token_ids .append (token_id )
@@ -361,6 +395,8 @@ def __init__(
361395 top_p : float = 0.8 ,
362396 top_k : int = 1 ,
363397 enable_graph : bool = False ,
398+ enable_chunk_prefill_graph : bool = False ,
399+ chunk_size : int = 0 ,
364400 attn_backend : str = "default" ,
365401 skip_load : bool = False ,
366402 ):
@@ -398,6 +434,8 @@ def __init__(
398434 top_p = top_p ,
399435 top_k = top_k ,
400436 enable_graph = enable_graph ,
437+ enable_chunk_prefill_graph = enable_chunk_prefill_graph ,
438+ chunk_size = chunk_size ,
401439 attn_backend = attn_backend ,
402440 skip_load = skip_load ,
403441 )
@@ -539,6 +577,8 @@ def __init__(
539577 top_p : float = 0.8 ,
540578 top_k : int = 1 ,
541579 enable_graph : bool = False ,
580+ enable_chunk_prefill_graph : bool = False ,
581+ chunk_size : int = 0 ,
542582 attn_backend : str = "default" ,
543583 ):
544584 """Initialize AsyncLLMEngine.
@@ -575,6 +615,8 @@ def __init__(
575615 top_p = top_p ,
576616 top_k = top_k ,
577617 enable_graph = enable_graph ,
618+ enable_chunk_prefill_graph = enable_chunk_prefill_graph ,
619+ chunk_size = chunk_size ,
578620 attn_backend = attn_backend ,
579621 )
580622 self .engine = LLMEngine (config )
0 commit comments