Commit 79148ba

[None][fix] Draft KV cache should not allocate host memory

When using one-model speculative decoding with a separate draft KV cache (e.g. EAGLE3), the draft cache inherits the target's KvCacheConfig, which may have a non-zero host_cache_size. This causes unnecessary host memory allocation for the draft cache. Only the target model should use host offloading, since draft tokens are transient and may be rejected during verification.

Fix: set host_cache_size=0 on the draft KV cache config before creating the draft KV cache manager.

Signed-off-by: Shang-Pin Sheng <shang-pin@tmatehq.com>
1 parent 64b5c79 commit 79148ba

1 file changed

Lines changed: 5 additions & 0 deletions

tensorrt_llm/_torch/pyexecutor/_util.py

@@ -695,6 +695,11 @@ def _create_one_model_draft_kv_cache_manager(
         # falls back to the target model's config for MTP mode.
         sparse_attn_config = effective_draft_config.sparse_attention_config
         draft_kv_config = kv_cache_config_override if kv_cache_config_override is not None else self._kv_cache_config
+        # Draft KV cache should not allocate host memory — only the target
+        # model uses host offloading. Zero out host_cache_size to prevent
+        # unnecessary host memory allocation for the draft cache.
+        draft_kv_config = draft_kv_config.model_copy(
+            update={'host_cache_size': 0})
         return _create_kv_cache_manager(
             model_engine=None,
             kv_cache_manager_cls=draft_kv_cache_manager_cls,
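
The copy-then-override pattern in this hunk can be sketched in isolation. The example below is a minimal illustration, not the project's code: `KvCacheConfigSketch` and `select_draft_kv_config` are hypothetical names, the second field exists only to show that other settings carry over, and a frozen standard-library dataclass with `dataclasses.replace` stands in for the `model_copy(update=...)` call used in the diff. The point it shows is that the draft config gets `host_cache_size=0` while the target's config object is left unmodified.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class KvCacheConfigSketch:
    """Hypothetical stand-in for the real KV cache config (illustrative fields only)."""
    host_cache_size: int = 0                # bytes of host memory for offloading
    free_gpu_memory_fraction: float = 0.9   # unrelated setting that must carry over


def select_draft_kv_config(
    target_config: KvCacheConfigSketch,
    override: Optional[KvCacheConfigSketch] = None,
) -> KvCacheConfigSketch:
    """Pick the draft config (override if given, else the target's) and
    return a copy with host offloading disabled."""
    base = override if override is not None else target_config
    # replace() builds a new object, so the target's config is not mutated.
    return replace(base, host_cache_size=0)


target = KvCacheConfigSketch(host_cache_size=8 * 1024**3)
draft = select_draft_kv_config(target)
```

Copying rather than assigning `host_cache_size` in place matters here because, without an override, the draft and target would otherwise share the same config object, and zeroing the field in place would also disable host offloading for the target.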
