19 changes: 19 additions & 0 deletions atom/model_engine/block_manager.py
@@ -118,6 +118,15 @@ def can_allocate(self, seq: Sequence) -> bool:
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
+ # If the entire prompt would be cached, force the last full block
+ # to recompute so prefill has at least one token to forward and
+ # produce logits for the next-token sampler.
+ if (
+     not cache_miss
+     and i == seq.num_blocks - 1
+     and len(token_ids) == self.block_size
+ ):
+     cache_miss = True
if cache_miss:
needed_free += 1
return (
@@ -142,6 +151,16 @@ def allocate(self, seq: Sequence):
)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
+ # If the entire prompt would be cached, force the last full block
+ # to recompute so prefill has at least one token to forward and
+ # produce logits for the next-token sampler. Must mirror the same
+ # condition in can_allocate() so the block budget agrees.
+ if (
+     not cache_miss
+     and i == seq.num_blocks - 1
+     and len(token_ids) == self.block_size
+ ):
+     cache_miss = True
if cache_miss:
block_id = self._pop_free_block()
block = self._allocate_block(block_id)
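
Why the forced miss matters: prefill only forwards the prompt tokens not already covered by the prefix cache, and the sampler needs logits for the last prompt position, so at least one token must stay uncached. A minimal sketch with hypothetical names (tokens_to_prefill is not part of the engine):

# Minimal sketch, hypothetical names; the engine's real prefill path differs.
def tokens_to_prefill(prompt: list[int], num_cached_tokens: int) -> list[int]:
    """Prefill forwards only the uncached tail of the prompt."""
    return prompt[num_cached_tokens:]

prompt = [1, 2, 3, 4, 5, 6, 7, 8]  # two full blocks with block_size = 4

# Fully cached prompt: nothing left to forward, so no logits for the sampler.
assert tokens_to_prefill(prompt, num_cached_tokens=8) == []

# With the last full block forced to recompute, one block's worth of tokens is
# forwarded and the final position's logits feed the next-token sampler.
assert tokens_to_prefill(prompt, num_cached_tokens=4) == [5, 6, 7, 8]
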
6 changes: 1 addition & 5 deletions atom/model_ops/attention_mla.py
@@ -698,10 +698,6 @@ def forward_impl_server_mode(
kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache

if context.is_prefill and not use_prefill_mla:
- use_prefix_cache = (
-     attn_metadata.has_cached and self.kv_b_proj.weight.dtype != dtypes.fp4x2
- )
-
prefill_q = self.q_proj(q, x_scale=q_scale).view(
-1, self.num_heads, self.qk_head_dim
)
@@ -718,7 +714,7 @@
scale=self._k_scale,
)

- if use_prefix_cache:
+ if attn_metadata.has_cached:
# k_full/v_full are used for attention compute; gather_kv_b_proj reads
# fp8 from cache and dequantizes internally, so output must be model dtype
k_full = torch.empty(
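
The dropped dtype guard in context: previously the prefix-cache branch was skipped when kv_b_proj held fp4x2 weights; now cached KV alone selects the branch, because gather_kv_b_proj dequantizes the fp8 cache internally and writes model-dtype output. A rough, runnable sketch of just that control flow, with placeholder values standing in for what the module reads from attn_metadata and the layer config:

import torch

# Illustrative placeholders, not the module's real attributes.
class _Meta:
    has_cached = True

attn_metadata = _Meta()
num_kv_tokens, num_heads, qk_head_dim = 16, 8, 64
model_dtype, device = torch.bfloat16, "cpu"

# Old condition (removed): attn_metadata.has_cached and kv_b_proj dtype != fp4x2.
# New condition: cached KV alone; the gather step handles fp8 -> model dtype.
if attn_metadata.has_cached:
    k_full = torch.empty(
        num_kv_tokens, num_heads, qk_head_dim, dtype=model_dtype, device=device
    )
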
12 changes: 7 additions & 5 deletions tests/test_block_manager.py
@@ -301,10 +301,11 @@ def test_preempt_and_reschedule_reuses_cache(self, seq_factory):
assert s1.num_cached_tokens == 0
assert s1.block_table == []

- # Re-allocate — should get cache hits on both blocks
+ # Re-allocate — first block is a cache hit; the last full block is
+ # force-recomputed so prefill has at least one token to forward.
s1_retry = seq_factory([1, 2, 3, 4, 5, 6, 7, 8])
bm.allocate(s1_retry)
- assert s1_retry.num_cached_tokens == 8 # both blocks cached
+ assert s1_retry.num_cached_tokens == 4


# ── Edge cases ───────────────────────────────────────────────────────────
@@ -325,8 +326,9 @@ def test_single_token_no_cache(self, seq_factory):
# Partial block → hash is -1 → no caching
assert s2.num_cached_tokens == 0

- def test_exact_block_size_fully_cached(self, seq_factory):
-     """Sequence with exactly block_size tokens — fully cached on reuse."""
+ def test_exact_block_size_last_block_recomputed(self, seq_factory):
+     """Single-block prompt: last full block is force-recomputed on reuse so
+     prefill has at least one token to forward and produce logits."""
cfg = MockConfig(
num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True
)
@@ -336,7 +338,7 @@ def test_exact_block_size_fully_cached(self, seq_factory):
bm.deallocate(s1)
s2 = seq_factory([1, 2, 3, 4])
bm.allocate(s2)
- assert s2.num_cached_tokens == 4
+ assert s2.num_cached_tokens == 0

def test_free_block_ids_set_consistent(self, block_manager, seq_factory):
"""free_block_ids_set stays consistent through allocate/deallocate."""
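
Read together, the updated assertions encode one rule for a warm prefix cache: a trailing partial block is never cached, and the last full block of a block-aligned prompt is always recomputed. A hypothetical helper (not part of the test suite) that states the same rule:

def expected_cached_tokens(prompt_len: int, block_size: int) -> int:
    """Tokens reusable from a fully warm prefix cache under the new rule."""
    full_blocks = prompt_len // block_size
    if full_blocks and prompt_len % block_size == 0:
        full_blocks -= 1  # last full block is forced to recompute
    return full_blocks * block_size

assert expected_cached_tokens(8, 4) == 4  # two-block prompt: first block reused
assert expected_cached_tokens(4, 4) == 0  # single full block: recomputed
assert expected_cached_tokens(1, 4) == 0  # partial block: never cached
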