Fpz/chunk prefill#740
Open
jiayyu wants to merge 6 commits into
Open
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR adds chunked prefill support so prompts can be prefetched in multiple steps when constrained by max_num_batched_tokens, and updates the runtime/attention plumbing to correctly handle prefix-cache + partial-block scheduling. It also renames the “cached tokens” tracking field to num_kv_computed across the engine and updates tests accordingly.
Changes:
- Add chunked-prefill scheduling in
Scheduler, including partial-prefill tracking and forwardingnum_kv_computedintoScheduledBatch. - Extend forward/attention metadata to support partial-prefill execution and correct KV gather behavior when block tables are converted.
- Update model runner to optionally skip logits/sampling for “intermediate” prefill chunks, plus test updates and config plumbing.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_scheduler.py | Updates scheduler tests for chunked prefill behavior and num_kv_computed semantics. |
| tests/test_block_manager.py | Updates block manager tests to use num_kv_computed. |
| tests/conftest.py | Adds enable_chunked_prefill to MockConfig. |
| atom/utils/forward_context.py | Adds Context.is_partial_prefill and AttentionMetaData.orig_block_tables. |
| atom/model_ops/base_attention.py | Adjusts fp8 dequant path in gather kernel and supports per_token_quant plumbing. |
| atom/model_ops/attentions/backends.py | Uses num_kv_computed and fixes slot mapping for partial-block prefills; ensures cu_seqlens buffers are copied. |
| atom/model_ops/attention_mla.py | Tweaks prefix-cache gating and weight preshuffle handling; fixes head-dim usage for V buffer. |
| atom/model_ops/attention_mha.py | Tracks cache layout for prefix gather, uses orig_block_tables, and passes per_token_quant. |
| atom/model_engine/sequence.py | Renames num_cached_tokens to num_kv_computed and removes unused cached-block helper. |
| atom/model_engine/scheduler.py | Implements chunked prefill scheduling/resume, partial-prefill bookkeeping, and batch metadata updates. |
| atom/model_engine/model_runner.py | Propagates is_partial_prefill and skips logits/sampling for intermediate chunks; fixes token indexing for deferred/new decode layout. |
| atom/model_engine/engine_core.py | Passes scheduled_batch into Scheduler.postprocess() for KV-progress updates. |
| atom/model_engine/block_manager.py | Updates prefix-cache hit accounting to increment num_kv_computed. |
| atom/config.py | Adds enable_chunked_prefill config flag (default enabled). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+568
to
+570
| if num_new_tokens_est > budget_remaining and num_batched_tokens > 0: | ||
| self.waiting.appendleft(seq) | ||
| break |
Comment on lines
+1705
to
1710
| if context.is_partial_prefill: | ||
| # B scheme: skip compute_logits for intermediate chunks | ||
| logits = None | ||
| else: | ||
| logits = self.model.compute_logits(hidden_states) | ||
| else: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist