diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py
index d62a2c413..7e081a725 100644
--- a/atom/model_engine/block_manager.py
+++ b/atom/model_engine/block_manager.py
@@ -118,6 +118,15 @@ def can_allocate(self, seq: Sequence) -> bool:
             block_id = self.hash_to_block_id.get(h, -1)
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                 cache_miss = True
+            # If the entire prompt would be cached, force the last full block
+            # to recompute so prefill has at least one token to forward and
+            # produce logits for the next-token sampler.
+            if (
+                not cache_miss
+                and i == seq.num_blocks - 1
+                and len(token_ids) == self.block_size
+            ):
+                cache_miss = True
             if cache_miss:
                 needed_free += 1
         return (
@@ -142,6 +151,16 @@ def allocate(self, seq: Sequence):
             )
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                 cache_miss = True
+            # If the entire prompt would be cached, force the last full block
+            # to recompute so prefill has at least one token to forward and
+            # produce logits for the next-token sampler. Must mirror the same
+            # condition in can_allocate() so the block budget agrees.
+            if (
+                not cache_miss
+                and i == seq.num_blocks - 1
+                and len(token_ids) == self.block_size
+            ):
+                cache_miss = True
             if cache_miss:
                 block_id = self._pop_free_block()
                 block = self._allocate_block(block_id)
diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py
index 9e2faa889..c92c231c0 100644
--- a/atom/model_ops/attention_mla.py
+++ b/atom/model_ops/attention_mla.py
@@ -698,10 +698,6 @@ def forward_impl_server_mode(
         kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache
 
         if context.is_prefill and not use_prefill_mla:
-            use_prefix_cache = (
-                attn_metadata.has_cached and self.kv_b_proj.weight.dtype != dtypes.fp4x2
-            )
-
             prefill_q = self.q_proj(q, x_scale=q_scale).view(
                 -1, self.num_heads, self.qk_head_dim
             )
@@ -718,7 +714,7 @@ def forward_impl_server_mode(
                 scale=self._k_scale,
             )
 
-            if use_prefix_cache:
+            if attn_metadata.has_cached:
                 # k_full/v_full are used for attention compute; gather_kv_b_proj reads
                 # fp8 from cache and dequantizes internally, so output must be model dtype
                 k_full = torch.empty(
diff --git a/tests/test_block_manager.py b/tests/test_block_manager.py
index 60a6c1061..9ad059ff2 100644
--- a/tests/test_block_manager.py
+++ b/tests/test_block_manager.py
@@ -301,10 +301,11 @@ def test_preempt_and_reschedule_reuses_cache(self, seq_factory):
         assert s1.num_cached_tokens == 0
         assert s1.block_table == []
 
-        # Re-allocate — should get cache hits on both blocks
+        # Re-allocate — first block is a cache hit; the last full block is
+        # force-recomputed so prefill has at least one token to forward.
         s1_retry = seq_factory([1, 2, 3, 4, 5, 6, 7, 8])
         bm.allocate(s1_retry)
-        assert s1_retry.num_cached_tokens == 8  # both blocks cached
+        assert s1_retry.num_cached_tokens == 4
 
     # ── Edge cases ───────────────────────────────────────────────────────────
 
@@ -325,8 +326,9 @@ def test_single_token_no_cache(self, seq_factory):
         # Partial block → hash is -1 → no caching
         assert s2.num_cached_tokens == 0
 
-    def test_exact_block_size_fully_cached(self, seq_factory):
-        """Sequence with exactly block_size tokens — fully cached on reuse."""
+    def test_exact_block_size_last_block_recomputed(self, seq_factory):
+        """Single-block prompt: last full block is force-recomputed on reuse so
+        prefill has at least one token to forward and produce logits."""
         cfg = MockConfig(
             num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True
         )
@@ -336,7 +338,7 @@ def test_exact_block_size_fully_cached(self, seq_factory):
         bm.deallocate(s1)
         s2 = seq_factory([1, 2, 3, 4])
         bm.allocate(s2)
-        assert s2.num_cached_tokens == 4
+        assert s2.num_cached_tokens == 0
 
     def test_free_block_ids_set_consistent(self, block_manager, seq_factory):
         """free_block_ids_set stays consistent through allocate/deallocate."""