Skip to content

Commit e8a4154

Browse files
psiddh and claude committed
DYNAMIC_UNBOUND: lazy KV cache allocation for on-device LLMs
Adds DYNAMIC_UNBOUND tensor support to ExecuTorch, enabling lazy KV cache allocation that defers memory allocation to first inference instead of model load time.

Export (Python):
- MarkDynamicUnboundPass tags KV cache buffers as DYNAMIC_UNBOUND
- SpecPropPass reads the flag and sets shape_dynamism accordingly
- Memory planner skips DYNAMIC_UNBOUND tensors
- emit_mutable_buffer_names auto-enabled when MarkDynamicUnboundPass detected
- Export flag: --lazy_kv_cache

Runtime (C++):
- DynamicAllocator interface with PalDynamicAllocator (malloc-based) and TrackingDynamicAllocator (with stats) implementations
- TensorImpl gains dynamic_allocator_ and capacity_bytes_ fields behind ET_DYNAMIC_ALLOCATOR_ENABLED compile guard
- DYNAMIC_UNBOUND case in internal_resize_contiguous uses DynamicAllocator with 2x growth policy for amortized resizing
- tensor_parser_portable.cpp: DYNAMIC_UNBOUND tensors start with capacity_bytes=0 and nullptr data (lazy allocation)
- op_update_cache.cpp: maybe_resize_cache checks for null data pointer, triggers DynamicAllocator on first use
- op_sdpa.cpp: same null-data guard before update_cache calls
- method.cpp: FreeCall properly frees DYNAMIC_UNBOUND tensor memory
- MemoryManager accepts optional DynamicAllocator*
- Module::load_method creates PalDynamicAllocator when enabled
- util.h: get_rss_bytes reads /proc/self/statm for current RSS

Build:
- CMake option EXECUTORCH_ENABLE_DYNAMIC_ALLOCATOR adds -DET_DYNAMIC_ALLOCATOR_ENABLED
- All DYNAMIC_UNBOUND code guarded by #ifdef ET_DYNAMIC_ALLOCATOR_ENABLED

Tested on Samsung S23 with Qwen3 0.6B (fp16) and Qwen2.5-Math 1.5B (8da4w):
- Load RSS: ~100 MiB (vs ~2147 MiB without) — KV cache not pre-allocated
- First inference: +1.6 GB (KV cache allocated on demand)
- 10+ multi-turn conversations stable, no crashes
- Generation speed unchanged (10-37 tok/s)

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 02bad9d commit e8a4154

25 files changed

Lines changed: 817 additions & 30 deletions

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ if(NOT EXECUTORCH_ENABLE_PROGRAM_VERIFICATION)
178178
add_definitions(-DET_ENABLE_PROGRAM_VERIFICATION=0)
179179
endif()
180180

181+
if(EXECUTORCH_ENABLE_DYNAMIC_ALLOCATOR)
182+
add_definitions(-DET_DYNAMIC_ALLOCATOR_ENABLED)
183+
endif()
184+
181185
if(EXECUTORCH_ENABLE_EVENT_TRACER)
182186
add_definitions(-DET_EVENT_TRACER_ENABLED)
183187
endif()

examples/models/llama/export_llama_lib.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,15 @@ def build_args_parser() -> argparse.ArgumentParser:
418418
help="maximum length of context for model to remember",
419419
)
420420

421+
parser.add_argument(
422+
"--lazy_kv_cache",
423+
action="store_true",
424+
default=False,
425+
help="Mark KV cache buffers as DYNAMIC_UNBOUND so they are allocated "
426+
"lazily at runtime instead of at load time. Reduces initial memory "
427+
"usage when max_context_length is large.",
428+
)
429+
421430
parser.add_argument(
422431
"--local_global_attention",
423432
type=parse_list_of_ints,
@@ -784,6 +793,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
784793
local_global_attention=llm_config.model.local_global_attention,
785794
use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
786795
use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
796+
lazy_kv_cache=llm_config.export.lazy_kv_cache,
787797
)
788798
)
789799

@@ -1362,6 +1372,13 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
13621372
if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
13631373
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
13641374

1375+
if llm_config.export.lazy_kv_cache:
1376+
from executorch.exir.passes.mark_dynamic_unbound_pass import (
1377+
MarkDynamicUnboundPass,
1378+
)
1379+
1380+
additional_passes.append(MarkDynamicUnboundPass())
1381+
13651382
# export_to_edge
13661383
builder_manager = _prepare_for_llama_export(llm_config)
13671384
if (
@@ -1614,6 +1631,7 @@ def _get_source_transforms( # noqa
16141631
use_torchao_kernels_linear: bool = False,
16151632
use_torchao_kernels_tied_embedding: bool = False,
16161633
quantize_with_hqq: bool = True,
1634+
lazy_kv_cache: bool = False,
16171635
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
16181636
"""
16191637
Return a list of functions that transform a graph.
@@ -1740,6 +1758,24 @@ def _get_source_transforms( # noqa
17401758

17411759
if quantize_kv_cache:
17421760
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
1761+
# TODO: --lazy_kv_cache + --quantize_kv_cache is not yet
1762+
# supported. QuantizedKVCache.from_float reads the source cache shape
1763+
# to determine max_context_length; with lazy=True the shape is
1764+
# [B, 0, H, D], so it creates a permanently-zero-sized quantized cache
1765+
# that silently produces wrong results. Proper support requires:
1766+
# 1. Plumbing `lazy` into QuantizedKVCache (moderate: ~100 LOC across
1767+
# custom_kv_cache.py and quantized_sdpa.py).
1768+
# 2. Making QuantizedKVCache register its k/v buffers with seq_len=0
1769+
# and matching name patterns so MarkDynamicUnboundPass picks them up.
1770+
# 3. Ensuring the quantized update_cache C++ op also calls
1771+
# maybe_resize_cache (or equivalent) to grow before dequant+write.
1772+
# Estimated effort: a few days of work, mostly in the quantized SDPA
1773+
# kernel which packs/unpacks int4 cache values in a layout-specific way
1774+
# that complicates in-place growth.
1775+
assert not lazy_kv_cache, (
1776+
"--lazy_kv_cache and --quantize_kv_cache cannot be used together yet. "
1777+
"QuantizedKVCache does not support DYNAMIC_UNBOUND lazy allocation."
1778+
)
17431779
transforms.append(replace_kv_cache_with_quantized_kv_cache)
17441780
# Right now
17451781
transforms.append(replace_sdpa_with_quantized_sdpa)
exir/passes/mark_dynamic_unbound_pass.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List, Optional
8+
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
class MarkDynamicUnboundPass(ExportPass):
    """
    Tag placeholder nodes with ``et_dynamic_unbound`` metadata by name.

    Every placeholder whose name contains one of ``name_patterns`` as a
    substring gets ``meta["et_dynamic_unbound"] = True``. After
    ``SpecPropPass`` creates a ``TensorSpec`` for each placeholder,
    ``update_placeholder_tensor_specs`` reads this flag and sets the spec's
    ``shape_dynamism`` to ``DYNAMIC_UNBOUND``. The memory planner is then
    expected to skip those tensors, and the runtime to allocate their memory
    lazily via ``DynamicAllocator``.

    Typical usage: mark KV cache buffers so they start unallocated and grow
    on demand, avoiding the full upfront memory cost of max_context_length.

    NOTE(review): matching is a plain substring test against the placeholder
    name, so any placeholder (not only mutable buffers) whose name happens to
    contain a pattern is tagged as well -- confirm this is acceptable for the
    models being exported.

    Args:
        name_patterns: Substrings to look for in placeholder names. ``None``
            selects the defaults ``["k_cache", "v_cache"]``. An explicitly
            empty list is honoured and matches nothing.
    """

    # Substrings used when the caller does not supply ``name_patterns``.
    _DEFAULT_NAME_PATTERNS = ("k_cache", "v_cache")

    def __init__(
        self,
        name_patterns: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        # Distinguish "not provided" from "provided but empty": an ``or``
        # fallback would silently replace an explicit [] with the defaults.
        if name_patterns is None:
            name_patterns = self._DEFAULT_NAME_PATTERNS
        # Copy so later mutation of the caller's list cannot change matching.
        self.name_patterns = list(name_patterns)

    def placeholder(self, name: str, arg, meta):
        """Flag ``meta`` when ``name`` matches, then defer to ``ExportPass``."""
        if any(pattern in name for pattern in self.name_patterns):
            meta["et_dynamic_unbound"] = True
        return super().placeholder(name, arg, meta)

exir/passes/spec_prop_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from executorch.exir.delegate import executorch_call_delegate
1414
from executorch.exir.pass_base import ExportPass, ProxyValue
15+
from executorch.exir.schema import TensorShapeDynamism
1516
from executorch.exir.tensor import TensorSpec
1617
from torch.export.exported_program import ExportGraphSignature
1718
from torch.fx.node import Node
@@ -121,3 +122,7 @@ def update_placeholder_tensor_specs(
121122
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
122123
):
123124
spec.const = True
125+
if isinstance(spec, TensorSpec) and node.meta.get(
126+
"et_dynamic_unbound", False
127+
):
128+
spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_UNBOUND

extension/llm/custom_ops/custom_ops.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,11 @@ def _validate_update_cache_params(
207207
), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
208208

209209
torch._check_is_size(start_pos)
210-
if indices is None:
210+
# Bounds checks are skipped for DYNAMIC_UNBOUND (lazy) caches where
211+
# cache.size(1) starts at 0 and grows at runtime.
212+
if indices is None and cache.size(1) > 0:
211213
torch._check(start_pos < cache.size(1))
212-
assert start_pos < cache.size(
213-
1
214-
), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
215-
216214
torch._check((start_pos + seq_len) <= cache.size(1))
217-
assert (start_pos + seq_len) <= cache.size(
218-
1
219-
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
220215

221216
if indices is not None:
222217
assert (

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,14 @@ Tensor& sdpa_with_kv_cache_out(
593593

594594
ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
595595

596+
// For DYNAMIC_UNBOUND lazy KV cache: allocate if data is null.
597+
if (key_cache.const_data_ptr() == nullptr) {
598+
resize_tensor(key_cache, key_cache.sizes());
599+
}
600+
if (value_cache.const_data_ptr() == nullptr) {
601+
resize_tensor(value_cache, value_cache.sizes());
602+
}
603+
596604
update_cache(k_projected, key_cache, start_pos, seq_len);
597605
update_cache(v_projected, value_cache, start_pos, seq_len);
598606

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
1212
// @lint-ignore CLANGTIDY facebook-unused-include-check
1313
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1415

1516
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
1617

@@ -207,6 +208,35 @@ Tensor& update_cache_impl(
207208
}
208209
} // anonymous namespace
209210

211+
// Grow cache seq dimension if needed (for DYNAMIC_UNBOUND lazy KV cache).
212+
static bool maybe_resize_cache(
213+
RuntimeContext& ctx,
214+
const Tensor& value,
215+
Tensor& cache,
216+
int64_t start_pos) {
217+
ET_CHECK_OR_RETURN_FALSE(cache.dim() == 4, "cache must be a 4D tensor");
218+
ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");
219+
int64_t seq_len = value.size(1);
220+
int64_t required_seq = start_pos + seq_len;
221+
// Resize if cache is too small OR if data hasn't been allocated yet
222+
// (lazy DYNAMIC_UNBOUND cache starts with capacity_bytes=0 and null data).
223+
if (required_seq > cache.size(1) || cache.const_data_ptr() == nullptr) {
224+
int64_t new_seq = std::max(required_seq, cache.size(1));
225+
executorch::aten::SizesType new_sizes[] = {
226+
static_cast<executorch::aten::SizesType>(cache.size(0)),
227+
static_cast<executorch::aten::SizesType>(new_seq),
228+
static_cast<executorch::aten::SizesType>(cache.size(2)),
229+
static_cast<executorch::aten::SizesType>(cache.size(3)),
230+
};
231+
auto err = resize_tensor(
232+
cache, {new_sizes, static_cast<size_t>(cache.dim())});
233+
if (err != Error::Ok) {
234+
return false;
235+
}
236+
}
237+
return true;
238+
}
239+
210240
// Original update_cache_out function without indices parameter
211241
Tensor& update_cache_out(
212242
RuntimeContext& ctx,
@@ -215,6 +245,11 @@ Tensor& update_cache_out(
215245
const int64_t start_pos,
216246
Tensor& output) {
217247
int64_t seq_len = value.size(1);
248+
ET_KERNEL_CHECK(
249+
ctx,
250+
maybe_resize_cache(ctx, value, cache, start_pos),
251+
InvalidArgument,
252+
output);
218253
ET_KERNEL_CHECK(
219254
ctx,
220255
validate_cache_params(value, cache, start_pos, seq_len),
@@ -233,6 +268,11 @@ Tensor& update_cache_with_indices_out(
233268
const Tensor& indices,
234269
Tensor& output) {
235270
int64_t seq_len = value.size(1);
271+
ET_KERNEL_CHECK(
272+
ctx,
273+
maybe_resize_cache(ctx, value, cache, start_pos),
274+
InvalidArgument,
275+
output);
236276
ET_KERNEL_CHECK(
237277
ctx,
238278
validate_cache_params(value, cache, start_pos, seq_len, indices),
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Tests for lazy KV cache (DYNAMIC_UNBOUND) support.
9+
10+
Tests the update_cache custom op's ability to handle caches that start at
11+
seq_len=0 and grow on demand, which is the foundation for pay-as-you-go
12+
KV cache memory allocation.
13+
"""
14+
15+
# pyre-unsafe
16+
17+
import unittest
18+
19+
import torch
20+
21+
from executorch.extension.llm.custom_ops import custom_ops # noqa
22+
23+
24+
class LazyKVCacheUpdateTest(unittest.TestCase):
    """Checks for ``update_cache`` and ``CustomKVCache`` on the lazy KV cache path."""

    def setUp(self):
        # Fixed seed so the random values written into the cache are repeatable.
        torch.manual_seed(42)
        self.batch_size, self.num_heads, self.head_dim = 1, 4, 8
        self.max_seq_len = 64

    def test_update_cache_grows_from_zero(self):
        """Write ten single-token values one position at a time and read them back.

        NOTE(review): despite the method name, the cache below is allocated
        at ``max_seq_len`` up front; growth from a zero-length cache is not
        exercised by this test.
        """
        cache_shape = (
            self.batch_size,
            self.max_seq_len,
            self.num_heads,
            self.head_dim,
        )
        token_shape = (self.batch_size, 1, self.num_heads, self.head_dim)
        cache = torch.zeros(cache_shape, dtype=torch.float32)

        for pos in range(10):
            value = torch.randn(token_shape, dtype=torch.float32)
            torch.ops.llama.update_cache(value, cache, pos)
            written = cache[:, pos : pos + 1, :, :]
            self.assertTrue(
                torch.allclose(written, value),
                f"Cache mismatch at position {pos}",
            )

    def test_custom_kv_cache_lazy_init(self):
        """``lazy=True`` yields k/v buffers whose sequence dimension is 0."""
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomKVCache,
        )

        context_ceiling = 131072  # 128K ceiling
        lazy_cache = CustomKVCache(
            max_batch_size=1,
            max_context_length=context_ceiling,
            n_heads=4,
            head_dim=8,
            dtype=torch.float32,
            lazy=True,
        )
        k_seq = lazy_cache.k_cache.shape[1]
        v_seq = lazy_cache.v_cache.shape[1]
        self.assertEqual(k_seq, 0, "Lazy k_cache seq dim should be 0")
        self.assertEqual(v_seq, 0, "Lazy v_cache seq dim should be 0")
        # The configured ceiling is still recorded on the module.
        self.assertEqual(lazy_cache.max_context_length, 131072)

    def test_custom_kv_cache_non_lazy_init(self):
        """``lazy=False`` yields k/v buffers sized to ``max_context_length``."""
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomKVCache,
        )

        eager_cache = CustomKVCache(
            max_batch_size=1,
            max_context_length=64,
            n_heads=4,
            head_dim=8,
            dtype=torch.float32,
            lazy=False,
        )
        k_seq = eager_cache.k_cache.shape[1]
        v_seq = eager_cache.v_cache.shape[1]
        self.assertEqual(k_seq, 64)
        self.assertEqual(v_seq, 64)

    def test_replace_kv_cache_with_lazy(self):
        """``replace_kv_cache_with_custom_kv_cache(..., lazy=True)`` swaps in a lazy cache."""
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomKVCache,
            replace_kv_cache_with_custom_kv_cache,
        )
        from executorch.examples.models.llama.attention import KVCache

        class _CacheHolder(torch.nn.Module):
            """Minimal module whose only child is a stock ``KVCache``."""

            def __init__(self):
                super().__init__()
                # KVCache stores as [B, H, S, D]
                self.kv = KVCache(
                    max_batch_size=1,
                    max_context_length=128,
                    n_heads=4,
                    head_dim=8,
                    enable_dynamic_shape=False,
                    dtype=torch.float32,
                )

            def forward(self, x):
                return x

        holder = _CacheHolder()
        replace_kv_cache_with_custom_kv_cache(holder, lazy=True)
        replaced = holder.kv
        self.assertIsInstance(replaced, CustomKVCache)
        # CustomKVCache stores as [B, S, H, D], lazy means seq_dim=0
        self.assertEqual(replaced.k_cache.shape[1], 0)
117+
118+
119+
class LazyKVCacheMetaKernelTest(unittest.TestCase):
    """Check that ``update_cache`` accepts a zero-length cache without bounds errors.

    NOTE(review): both tests pass ordinary (default-device) tensors rather
    than ``device="meta"`` tensors, so which kernel actually runs depends on
    the op registrations, which are not shown here -- confirm that the meta
    kernel is really what is being exercised.
    """

    def test_meta_kernel_allows_start_pos_beyond_cache(self):
        """``start_pos=0`` on a cache with ``size(1) == 0`` must not raise."""
        # Cache with seq_len=0 (lazy)
        empty_cache = torch.zeros(1, 0, 4, 8)
        token = torch.randn(1, 1, 4, 8)
        # This should not raise — the runtime op handles resize
        outcome = torch.ops.llama.update_cache(token, empty_cache, 0)
        self.assertIsNotNone(outcome)

    def test_meta_kernel_allows_large_start_pos(self):
        """``start_pos=100`` on a cache with ``size(1) == 0`` must not raise."""
        token = torch.randn(1, 1, 4, 8)
        empty_cache = torch.zeros(1, 0, 4, 8)
        # Lazy cache (size(1)==0) skips bounds checks — runtime op handles resize
        outcome = torch.ops.llama.update_cache(token, empty_cache, 100)
        self.assertIsNotNone(outcome)
138+
139+
140+
if __name__ == "__main__":
141+
unittest.main()

0 commit comments

Comments
 (0)