Skip to content

Commit c323881

Browse files
authored
[https://nvbugs/6035425][fix] Fix KV cache host splitting logic (#14373)
Signed-off-by: Mike Iovine <miovine@nvidia.com>
1 parent 8e2b7b2 commit c323881

2 files changed

Lines changed: 325 additions & 115 deletions

File tree

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 133 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -925,21 +925,9 @@ def _create_one_model_draft_kv_cache_manager(
925925
is_disagg=self._is_disagg,
926926
)
927927

928-
def _split_kv_cache_budget_for_draft(self) -> Optional[KvCacheConfig]:
929-
"""Split KV cache budgets between target and draft KV caches.
930-
931-
When using KVCacheManagerV2 with a separate draft KV cache,
932-
max_gpu_total_bytes and host_cache_size each represent the total
933-
budget for both target and draft combined. This method splits both
934-
budgets proportionally based on their per-token KV cache sizes.
935-
936-
Returns a cloned KvCacheConfig for the draft, or None if no split is
937-
needed. Also modifies self._kv_cache_config in-place for the target.
938-
"""
939-
total_budget = self._kv_cache_config.max_gpu_total_bytes
940-
if total_budget is None or total_budget <= 0:
941-
return None
942-
928+
def _get_target_and_draft_cache_costs(
929+
self, ) -> Optional[tuple[CacheCost, CacheCost]]:
930+
"""Per-manager KV cache costs for target and draft layers."""
943931
total_kv = self._get_kv_size_per_token()
944932
target_kv = self._per_manager_cache_cost(
945933
self._kv_cache_manager_cls, self._model_engine.model.model_config)
@@ -949,10 +937,15 @@ def _split_kv_cache_budget_for_draft(self) -> Optional[KvCacheConfig]:
949937
intercept=total_kv.intercept - target_kv.intercept)
950938
if target_kv.slope <= 0 or draft_kv.slope <= 0:
951939
return None
940+
return target_kv, draft_kv
952941

953-
# Cover both managers' fixed costs first, then split the remaining
954-
# budget by per-token slope. With zero intercepts this reduces to the
955-
# original proportional split.
942+
def _compute_draft_budget_shares(
943+
self,
944+
total_budget: int,
945+
target_kv: CacheCost,
946+
draft_kv: CacheCost,
947+
) -> Optional[tuple[int, int]]:
948+
"""Split *total_budget* into (target_budget, draft_budget) byte shares."""
956949
intercept_total = target_kv.intercept + draft_kv.intercept
957950
slope_budget = total_budget - intercept_total
958951
if slope_budget <= 0:
@@ -965,62 +958,138 @@ def _split_kv_cache_budget_for_draft(self) -> Optional[KvCacheConfig]:
965958
draft_slope_share = int(slope_budget * draft_kv.slope / slope_total)
966959
draft_budget = draft_kv.intercept + draft_slope_share
967960
target_budget = total_budget - draft_budget
961+
return target_budget, draft_budget
962+
963+
def _split_kv_cache_budget_for_draft(
964+
self,
965+
budget_attr: str,
966+
draft_kv_cache_config: Optional[KvCacheConfig] = None,
967+
) -> Optional[KvCacheConfig]:
968+
"""Split a byte budget (attribute on ``KvCacheConfig``) between target
969+
and draft KV caches.
970+
971+
Splits the value of ``self._kv_cache_config.<budget_attr>`` using the
972+
affine target/draft cache costs, updates the target config in-place,
973+
and merges the draft share into ``draft_kv_cache_config`` (cloning the
974+
target config if needed).
975+
976+
Returns the (possibly newly created) draft config. The input
977+
``draft_kv_cache_config`` is returned unchanged when the split is not
978+
applicable (the budget is not set, or the per-manager cache costs are
979+
unavailable) — in those cases sharing ``self._kv_cache_config`` is
980+
correct.
981+
982+
The affine fixed (intercept) cost models GPU-resident state (e.g. mamba
983+
SSM state). It is only charged against ``max_gpu_total_bytes``; for any
984+
other budget (e.g. ``host_cache_size``, which is host offload memory the
985+
GPU-resident state never occupies) the intercept is dropped so the split
986+
stays proportional to the per-token cost.
987+
988+
When the split is *infeasible* (the combined fixed cost meets or exceeds
989+
the budget — only possible for ``max_gpu_total_bytes`` after the above)
990+
the shortfall is fatal: both managers need their fixed state resident in
991+
GPU memory, so the run would OOM. It raises ``ValueError`` rather than
992+
silently producing an unusable config. A defensive degrade-to-zero path
993+
for non-GPU budgets remains so the draft never silently inherits the full
994+
budget and double-allocates it.
995+
"""
996+
total_budget = getattr(self._kv_cache_config, budget_attr) or 0
997+
if total_budget <= 0:
998+
return draft_kv_cache_config
999+
1000+
cache_costs = self._get_target_and_draft_cache_costs()
1001+
if cache_costs is None:
1002+
return draft_kv_cache_config
1003+
target_kv, draft_kv = cache_costs
1004+
1005+
# The fixed (intercept) cost models GPU-resident state such as mamba SSM
1006+
# state; it does not consume host offload memory. When splitting a
1007+
# non-GPU budget (e.g. host_cache_size), drop the intercept so the split
1008+
# stays proportional to the per-token (slope) cost instead of being
1009+
# spuriously starved by a GPU-only fixed cost.
1010+
if budget_attr != "max_gpu_total_bytes":
1011+
target_kv = CacheCost(slope=target_kv.slope)
1012+
draft_kv = CacheCost(slope=draft_kv.slope)
1013+
1014+
shares = self._compute_draft_budget_shares(total_budget, target_kv,
1015+
draft_kv)
1016+
if shares is None:
1017+
# The split is infeasible (combined fixed cost >= total budget).
1018+
intercept_total = target_kv.intercept + draft_kv.intercept
1019+
if budget_attr == "max_gpu_total_bytes":
1020+
# A GPU budget that cannot even fit the combined fixed cost is
1021+
# fatal: both managers need their fixed state resident in GPU
1022+
# memory, so the run would OOM. Fail fast with actionable
1023+
# guidance rather than producing an unusable zero-budget draft.
1024+
raise ValueError(
1025+
f"KV cache GPU budget ({total_budget / GB:.2f} GiB) is "
1026+
f"smaller than the combined fixed cost "
1027+
f"({intercept_total / GB:.2f} GiB, e.g. mamba SSM state) "
1028+
f"for target+draft. Increase free_gpu_memory_fraction or "
1029+
f"max_gpu_total_bytes, or reduce max_batch_size (the fixed "
1030+
f"cost scales with batch size).")
1031+
# Defensive: non-GPU budgets zero out the intercept above, so with a
1032+
# positive budget this branch is currently unreachable for them. It
1033+
# remains as a safety net guaranteeing that, should a non-GPU budget
1034+
# ever carry a fixed cost it cannot fit, we degrade gracefully rather
1035+
# than letting both managers inherit the full budget and
1036+
# double-allocate it: keep the full budget on the target and zero the
1037+
# draft's share for this attribute.
1038+
logger.warning(
1039+
f"Cannot split KV cache {budget_attr} between target and draft; "
1040+
f"assigning the draft a zero {budget_attr} budget to avoid "
1041+
f"double-allocating the full budget.")
1042+
if draft_kv_cache_config is None:
1043+
draft_kv_cache_config = self._kv_cache_config.model_copy()
1044+
setattr(draft_kv_cache_config, budget_attr, 0)
1045+
return draft_kv_cache_config
1046+
target_budget, draft_budget = shares
9681047

9691048
logger.info(
970-
f"Splitting KV cache budget: total={total_budget / GB:.2f} GiB, "
1049+
f"Splitting KV cache {budget_attr}: total={total_budget / GB:.2f} GiB, "
9711050
f"target={target_budget / GB:.2f} GiB ({target_kv}), "
9721051
f"draft={draft_budget / GB:.2f} GiB ({draft_kv})")
9731052

974-
self._kv_cache_config.max_gpu_total_bytes = target_budget
975-
976-
draft_kv_cache_config = self._kv_cache_config.model_copy()
977-
draft_kv_cache_config.max_gpu_total_bytes = draft_budget
978-
979-
host_budget = self._kv_cache_config.host_cache_size
980-
if host_budget is not None and host_budget > 0:
981-
draft_ratio = draft_budget / total_budget
982-
draft_host_budget = int(host_budget * draft_ratio)
983-
target_host_budget = host_budget - draft_host_budget
984-
self._kv_cache_config.host_cache_size = target_host_budget
985-
draft_kv_cache_config.host_cache_size = draft_host_budget
986-
logger.info(
987-
f"Splitting KV cache host budget: total={host_budget / GB:.2f} GiB, "
988-
f"target={target_host_budget / GB:.2f} GiB, "
989-
f"draft={draft_host_budget / GB:.2f} GiB")
990-
1053+
setattr(self._kv_cache_config, budget_attr, target_budget)
1054+
if draft_kv_cache_config is None:
1055+
draft_kv_cache_config = self._kv_cache_config.model_copy()
1056+
setattr(draft_kv_cache_config, budget_attr, draft_budget)
9911057
return draft_kv_cache_config
9921058

1059+
def _needs_gpu_kv_cache_budget_split(self) -> bool:
1060+
"""Whether max_gpu_total_bytes must be split per manager."""
1061+
if issubclass(self._kv_cache_manager_cls, KVCacheManagerV2):
1062+
return self._should_create_separate_draft_kv_cache()
1063+
return is_vswa_enabled(self._kv_cache_config)
1064+
9931065
def build_managers(self,
9941066
resources: Dict,
9951067
estimating_kv_cache: bool = False) -> None:
9961068
"""Construct KV caches for model and draft model (if applicable)."""
9971069
if self._skip_est:
9981070
self.configure_kv_cache_capacity()
9991071

1000-
# For V2 with separate one-model draft KV cache, split the total budget
1001-
# between target and draft before creating either manager.
1002-
# Only split for the final managers, not during estimation — estimation
1003-
# uses max_tokens-based logic and must not have its config mutated.
1004-
# Two-model draft is excluded: V2 does not support two-model mode.
1005-
draft_kv_cache_config = None
1006-
if (not estimating_kv_cache
1007-
and self._should_create_separate_draft_kv_cache()
1008-
and issubclass(self._kv_cache_manager_cls, KVCacheManagerV2)):
1009-
draft_kv_cache_config = self._split_kv_cache_budget_for_draft()
1010-
1011-
# Also split for V1 VSWA. The VSWA pool is sized directly from
1012-
# max_gpu_total_bytes and ignores max_tokens, so without splitting
1013-
# both target and draft each allocate the full combined budget.
1014-
# V1 non-VSWA does not need this: max_tokens caps the block count
1015-
# per model, giving each a proportional share of the budget.
1072+
# Split combined KV cache budgets before creating managers. Skip during
1073+
# estimation — estimation uses max_tokens-based logic and must not
1074+
# mutate the config.
10161075
has_draft = (
10171076
self._draft_model_engine is not None # two-model
10181077
or self._should_create_separate_draft_kv_cache()) # one-model
1019-
if (not estimating_kv_cache and has_draft
1020-
and draft_kv_cache_config is None
1021-
and not issubclass(self._kv_cache_manager_cls, KVCacheManagerV2)
1022-
and is_vswa_enabled(self._kv_cache_config)):
1023-
draft_kv_cache_config = self._split_kv_cache_budget_for_draft()
1078+
draft_kv_cache_config = None
1079+
if not estimating_kv_cache and has_draft:
1080+
# Used when each manager sizes pools from max_gpu_total_bytes (V2
1081+
# and V1 VSWA). V1 non-VSWA GPU uses shared max_tokens instead.
1082+
if self._needs_gpu_kv_cache_budget_split():
1083+
draft_kv_cache_config = self._split_kv_cache_budget_for_draft(
1084+
"max_gpu_total_bytes", draft_kv_cache_config)
1085+
# KVCacheManagerV2 does not support two-model draft budget splitting.
1086+
v2_two_model = (issubclass(self._kv_cache_manager_cls,
1087+
KVCacheManagerV2)
1088+
and self._draft_model_engine is not None)
1089+
if not v2_two_model:
1090+
# Each manager sizes its host pool from host_cache_size directly.
1091+
draft_kv_cache_config = self._split_kv_cache_budget_for_draft(
1092+
"host_cache_size", draft_kv_cache_config)
10241093

10251094
kv_cache_manager = self._create_kv_cache_manager(
10261095
self._model_engine, estimating_kv_cache)
@@ -1037,14 +1106,19 @@ def build_managers(self,
10371106
assert draft_kv_cache_config is None, (
10381107
"KVCacheManagerV2 does not support two-model speculative "
10391108
"decoding with separate draft KV cache budget splitting.")
1040-
# For V1 VSWA, apply the draft's split budget temporarily
1109+
# For V1, apply the draft's split budgets temporarily.
10411110
if draft_kv_cache_config is not None:
10421111
saved_budget = self._kv_cache_config.max_gpu_total_bytes
1043-
self._kv_cache_config.max_gpu_total_bytes = draft_kv_cache_config.max_gpu_total_bytes
1112+
saved_host = self._kv_cache_config.host_cache_size
1113+
self._kv_cache_config.max_gpu_total_bytes = (
1114+
draft_kv_cache_config.max_gpu_total_bytes)
1115+
self._kv_cache_config.host_cache_size = (
1116+
draft_kv_cache_config.host_cache_size)
10441117
draft_kv_cache_manager = self._create_kv_cache_manager(
10451118
self._draft_model_engine, estimating_kv_cache)
10461119
if draft_kv_cache_config is not None:
10471120
self._kv_cache_config.max_gpu_total_bytes = saved_budget
1121+
self._kv_cache_config.host_cache_size = saved_host
10481122
# One-model speculative decoding with different KV layouts
10491123
elif self._should_create_separate_draft_kv_cache():
10501124
draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager(

0 commit comments

Comments
 (0)