
Commit 501ba12

navsud authored and facebook-github-bot committed
Allow chunked prefill when num_prompt_tokens > max_seq_len
Summary: Remove the early `num_prompt_tokens <= max_seq_len` check in TextLLMRunner. `TextPrefiller::prefill()` already supports chunked prefill: when the prompt is longer than `max_seq_len` it splits the input into `max_seq_len`-sized chunks and prefills them sequentially. The previous check rejected this valid case, breaking models exported with `max_seq_len < max_context_len` (e.g. a 1024 prefill chunk over a 4096 KV cache).

The total-capacity bound is preserved:

- For non-sliding-window models (`max_seq_len >= max_context_len`), the existing `pos_ + num_prompt_tokens < max_context_len` check is unchanged.
- For sliding-window models (`max_seq_len < max_context_len`), a new per-call check `num_prompt_tokens < max_context_len` ensures the prompt itself fits in the KV cache; `pos_` doesn't represent consumed capacity for these models since the model handles position wrapping internally.

Differential Revision: D101728720
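For illustration, here is a minimal sketch of the chunked-prefill loop the summary describes. It assumes a simplified interface: `prefill_chunked` and `run_model_step` are hypothetical names, and the real `TextPrefiller::prefill()` differs in its tensor plumbing and error handling.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Hypothetical sketch: split a long prompt into max_seq_len-sized chunks
    // and prefill them sequentially, as the summary describes
    // TextPrefiller::prefill() doing.
    int64_t prefill_chunked(
        const std::vector<uint64_t>& prompt_tokens,
        int64_t max_seq_len,
        int64_t start_pos) {
      int64_t pos = start_pos;
      size_t offset = 0;
      while (offset < prompt_tokens.size()) {
        // Each chunk holds at most max_seq_len tokens; the last one may be
        // shorter.
        const size_t chunk_len = std::min<size_t>(
            static_cast<size_t>(max_seq_len), prompt_tokens.size() - offset);
        // run_model_step(...) is a stand-in for the model forward pass; it
        // would consume prompt_tokens[offset, offset + chunk_len) at
        // position pos here.
        offset += chunk_len;
        pos += static_cast<int64_t>(chunk_len);
      }
      return pos;  // position of the first token to be generated
    }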
1 parent 32702ac commit 501ba12

1 file changed: extension/llm/runner/text_llm_runner.cpp (19 additions, 10 deletions)
@@ -138,16 +138,16 @@ Error TextLLMRunner::generate(
       num_prompt_tokens >= 1,
       InvalidArgument,
       "Expected at least 1 prompt token");
-  ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens <= max_seq_len,
-      InvalidArgument,
-      "num_prompt_tokens %d > max_seq_len %" PRId64
-      ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
-      num_prompt_tokens,
-      max_seq_len);
-  // For non-sliding-window models, also check that we won't exceed
-  // KV cache capacity. Sliding window models (where max_seq_len <
-  // max_context_len) handle position wrapping internally.
+  // Note: We intentionally do NOT enforce num_prompt_tokens <= max_seq_len
+  // here. TextPrefiller::prefill() supports chunked prefill: when
+  // num_prompt_tokens > max_seq_len it splits the prompt into max_seq_len
+  // chunks and prefills them sequentially. Models that were exported with
+  // max_seq_len < max_context_len (e.g. a 1024 prefill chunk over a 4096 KV
+  // cache) rely on this behavior.
+  // Ensure the prompt fits within total KV cache capacity. For
+  // sliding-window models (where max_seq_len < max_context_len) the model
+  // handles position wrapping internally, so pos_ doesn't represent
+  // consumed capacity and we only need a per-call bound.
   if (max_seq_len >= max_context_len) {
     ET_CHECK_OR_RETURN_ERROR(
         pos_ + num_prompt_tokens < max_context_len,
@@ -158,6 +158,15 @@ Error TextLLMRunner::generate(
         pos_,
         num_prompt_tokens,
         max_context_len);
+  } else {
+    ET_CHECK_OR_RETURN_ERROR(
+        num_prompt_tokens < max_context_len,
+        InvalidArgument,
+        "num_prompt_tokens %d >= max_context_len %" PRId64
+        ", Prompt exceeds KV cache capacity - please reduce prompt size or "
+        "increase max_context_len in your export script",
+        num_prompt_tokens,
+        max_context_len);
   }
 
   // print prompts
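Distilled from the diff above, the two-branch capacity check can be read as the following standalone predicate. This is a sketch only: `prompt_fits` is a hypothetical name, and the runner itself reports failures through ET_CHECK_OR_RETURN_ERROR rather than returning a bool.

    #include <cstdint>

    // Sketch of the capacity logic added in this commit.
    bool prompt_fits(
        int64_t pos,
        int64_t num_prompt_tokens,
        int64_t max_seq_len,
        int64_t max_context_len) {
      if (max_seq_len >= max_context_len) {
        // Non-sliding-window model: pos tracks consumed KV cache capacity,
        // so bound the running total.
        return pos + num_prompt_tokens < max_context_len;
      }
      // Sliding-window model: positions wrap inside the model, so only the
      // per-call prompt length is bounded by the KV cache size.
      return num_prompt_tokens < max_context_len;
    }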
