Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,32 +730,18 @@ def _do_chunked_prefill_with_prefix_cache(
chunk_size = prefill_engine.prefill_chunk_size

# 1. Load the longest possible prefix from the cache
load_result = prefix_cache.load_existing_prefix(
self._prefix_cache, tuple_tokens, chunk_size
existing_prefix, remain_tokens = (
prefix_cache.load_existing_prefix_and_get_remain_tokens(
self._prefix_cache, tokens, chunk_size
)
)

existing_prefix = None
remain_tokens = tokens # Assume full prefill initially
original_common_prefix_len = 0

if load_result:
existing_prefix, original_common_prefix_len = load_result
# Calculate the tokens that still need to be prefilled
# common_prefix_tokens is already truncated to chunk_size multiple
# and ensures at least one token remains.
truncated_len = existing_prefix.common_prefix_tokens.shape[0]
remain_tokens = tokens[truncated_len:]
if existing_prefix is not None:
logger.debug(
"Prefix cache hit. Original common length: %d, Truncated length: %d,"
" Remaining tokens to prefill: %d",
original_common_prefix_len,
truncated_len,
"Prefix cache hit length: %d, Remaining tokens len to prefill: %d",
len(existing_prefix.common_prefix_tokens),
len(remain_tokens),
)
else:
logger.debug(
"Prefix cache miss or prefix too short. Prefilling all tokens."
)

# 2. Perform chunked prefill on the remaining tokens
prefill_result, first_token = self._do_chunked_prefill(
Expand Down
Loading