Skip to content

Commit fc0d025

Browse files
Refactor(PrefixCache): New load API, per-layer Tries, async ops & stats
Add async to prevent device_get blocking on the critical paths waiting prefill result. Use per-layer tries to prevent load cache from DRAM when common length tie. Add statistic for debug and benchmark.
1 parent 4aafd76 commit fc0d025

3 files changed

Lines changed: 731 additions & 239 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -730,32 +730,18 @@ def _do_chunked_prefill_with_prefix_cache(
730730
chunk_size = prefill_engine.prefill_chunk_size
731731

732732
# 1. Load the longest possible prefix from the cache
733-
load_result = prefix_cache.load_existing_prefix(
734-
self._prefix_cache, tuple_tokens, chunk_size
733+
existing_prefix, remain_tokens = (
734+
prefix_cache.load_existing_prefix_and_get_remain_tokens(
735+
self._prefix_cache, tokens, chunk_size
736+
)
735737
)
736738

737-
existing_prefix = None
738-
remain_tokens = tokens # Assume full prefill initially
739-
original_common_prefix_len = 0
740-
741-
if load_result:
742-
existing_prefix, original_common_prefix_len = load_result
743-
# Calculate the tokens that still need to be prefilled
744-
# common_prefix_tokens is already truncated to chunk_size multiple
745-
# and ensures at least one token remains.
746-
truncated_len = existing_prefix.common_prefix_tokens.shape[0]
747-
remain_tokens = tokens[truncated_len:]
739+
if existing_prefix is not None:
748740
logger.debug(
749-
"Prefix cache hit. Original common length: %d, Truncated length: %d,"
750-
" Remaining tokens to prefill: %d",
751-
original_common_prefix_len,
752-
truncated_len,
741+
"Prefix cache hit length: %d, Remaining tokens len to prefill: %d",
742+
len(existing_prefix.common_prefix_tokens),
753743
len(remain_tokens),
754744
)
755-
else:
756-
logger.debug(
757-
"Prefix cache miss or prefix too short. Prefilling all tokens."
758-
)
759745

760746
# 2. Perform chunked prefill on the remaining tokens
761747
prefill_result, first_token = self._do_chunked_prefill(

0 commit comments

Comments
 (0)