Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ endif()

# Disable the deprecated constant_buffer path.
add_definitions(-DET_ENABLE_DEPRECATED_CONSTANT_BUFFER=0)
# Define ET_DYNAMIC_ALLOCATOR_ENABLED for the build only when the
# EXECUTORCH_ENABLE_DYNAMIC_ALLOCATOR option is turned on.
if(EXECUTORCH_ENABLE_DYNAMIC_ALLOCATOR)
add_definitions(-DET_DYNAMIC_ALLOCATOR_ENABLED)
endif()

if(EXECUTORCH_ENABLE_EVENT_TRACER)
add_definitions(-DET_EVENT_TRACER_ENABLED)
Expand Down
35 changes: 35 additions & 0 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,15 @@ def build_args_parser() -> argparse.ArgumentParser:
help="maximum length of context for model to remember",
)

parser.add_argument(
"--lazy_kv_cache",
action="store_true",
default=False,
help="Mark KV cache buffers as DYNAMIC_UNBOUND so they are allocated "
"lazily at runtime instead of at load time. Reduces initial memory "
"usage when max_context_length is large.",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this because we do actually touch the full memory during attention?

Copy link
Copy Markdown
Contributor Author

@psiddh psiddh May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed out on this comment...

Yes, the full max_context_length buffer is allocated on first inference, not at load time. This defers the KV cache allocation from Module.load() to the first generate() call.

Sharing some Test Results:
Concrete KV cache costs for Qwen3-0.6B (28 layers, 8 KV heads, 128 head_dim, fp16):

▎ | max_context_length | KV Cache | Without PR | With PR (at load) |
▎ |-------------------- |----------|---------------|-------------------|
▎ | 128 (default) | 14 MB | Pre-allocated | 0 MB |
▎ | 1024 | 115 MB | Pre-allocated | 0 MB |
▎ | 2048 (standard) | 229 MB | Pre-allocated | 0 MB |
▎ | 4096 | 459 MB | Pre-allocated | 0 MB |
▎ | 16384 | 1.8 GB | OOM at load | 0 MB |

Note: KV cache sizes above are for fp16; fp32 doubles these values.

With this PR I increased max_context_length to 4096 on Samsung S23 (8GB RAM) and tested 10+ multi-turn conversations with stable RSS:

  • Load RSS: ~100-120 MiB (no KV cache)
  • First inference RSS: ~1730 MiB (KV cache allocated on demand)
  • Subsequent turns: stable, no memory growth

Key benefits:

  1. Lower RSS at startup → survives Android LMKD longer
  2. DynamicAllocator::free() enables freeing cache on memory pressure
    (onTrimMemory) // Future enhancements
  3. Unlocks larger context lengths (4K-16K) that would have OOM'd at load time without
    this feature / lazy allocation

)

parser.add_argument(
"--local_global_attention",
type=parse_list_of_ints,
Expand Down Expand Up @@ -808,6 +817,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
quantize_with_hqq=llm_config.quantization.use_hqq,
lazy_kv_cache=llm_config.export.lazy_kv_cache,
)
)

Expand Down Expand Up @@ -1430,6 +1440,12 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
# which corrupts the causal mask computation.
if llm_config.model.use_attention_sink:
additional_passes.append(InitializedMutableBufferPass(["cache_positions"]))
if llm_config.export.lazy_kv_cache:
from executorch.exir.passes.mark_dynamic_unbound_pass import (
MarkDynamicUnboundPass,
)

additional_passes.append(MarkDynamicUnboundPass())

# export_to_edge
builder_manager = _prepare_for_llama_export(llm_config)
Expand Down Expand Up @@ -1699,6 +1715,7 @@ def _get_source_transforms( # noqa
use_torchao_kernels_linear: bool = False,
use_torchao_kernels_tied_embedding: bool = False,
quantize_with_hqq: bool = True,
lazy_kv_cache: bool = False,
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
"""
Return a list of functions that transform a graph.
Expand Down Expand Up @@ -1825,6 +1842,24 @@ def _get_source_transforms( # noqa

if quantize_kv_cache:
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
# TODO: --lazy_kv_cache + --quantize_kv_cache is not yet
# supported. QuantizedKVCache.from_float reads the source cache shape
# to determine max_context_length; with lazy=True the shape is
# [B, 0, H, D], so it creates a permanently-zero-sized quantized cache
# that silently produces wrong results. Proper support requires:
# 1. Plumbing `lazy` into QuantizedKVCache (moderate: ~100 LOC across
# custom_kv_cache.py and quantized_sdpa.py).
# 2. Making QuantizedKVCache register its k/v buffers with seq_len=0
# and matching name patterns so MarkDynamicUnboundPass picks them up.
# 3. Ensuring the quantized update_cache C++ op also calls
# maybe_resize_cache (or equivalent) to grow before dequant+write.
# Estimated effort: a few days of work, mostly in the quantized SDPA
# kernel which packs/unpacks int4 cache values in a layout-specific way
# that complicates in-place growth.
assert not lazy_kv_cache, (
"--lazy_kv_cache and --quantize_kv_cache cannot be used together yet. "
"QuantizedKVCache does not support DYNAMIC_UNBOUND lazy allocation."
)
transforms.append(replace_kv_cache_with_quantized_kv_cache)
# Right now
transforms.append(replace_sdpa_with_quantized_sdpa)
Expand Down
36 changes: 36 additions & 0 deletions exir/passes/mark_dynamic_unbound_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

from executorch.exir.pass_base import ExportPass


class MarkDynamicUnboundPass(ExportPass):
    """
    Marks matching placeholder nodes with ``et_dynamic_unbound`` metadata.

    A placeholder matches when any entry of ``name_patterns`` occurs as a
    substring of the placeholder's name. Matching placeholders get
    ``meta["et_dynamic_unbound"] = True``; all others are passed through
    untouched.

    After ``SpecPropPass`` creates ``TensorSpec`` for each placeholder,
    ``update_placeholder_tensor_specs`` reads this flag and sets the spec's
    ``shape_dynamism`` to ``DYNAMIC_UNBOUND``. The memory planner then skips
    those tensors, and the runtime allocates their memory lazily via
    ``DynamicAllocator``.

    Typical usage: mark KV cache buffers so they start unallocated and grow
    on demand, avoiding the full upfront memory cost of max_context_length.

    Args:
        name_patterns: Substrings to look for in placeholder names. ``None``
            (the default) selects ``["k_cache", "v_cache"]``. An explicit
            empty list is honoured and marks nothing.
    """

    def __init__(
        self,
        name_patterns: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        # Test for None explicitly: a truthiness check would silently replace
        # a caller-supplied empty list with the defaults.
        if name_patterns is None:
            name_patterns = ["k_cache", "v_cache"]
        # Copy so later mutation of the caller's list cannot change which
        # placeholders this pass marks.
        self.name_patterns = list(name_patterns)

    def placeholder(self, name: str, arg, meta):
        """Flag the placeholder if its name matches, then defer to the base class."""
        if any(pattern in name for pattern in self.name_patterns):
            meta["et_dynamic_unbound"] = True
        return super().placeholder(name, arg, meta)
5 changes: 5 additions & 0 deletions exir/passes/spec_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature
from torch.fx.node import Node
Expand Down Expand Up @@ -130,3 +131,7 @@ def update_placeholder_tensor_specs(
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
):
spec.const = True
if isinstance(spec, TensorSpec) and node.meta.get(
"et_dynamic_unbound", False
):
spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_UNBOUND
11 changes: 3 additions & 8 deletions extension/llm/custom_ops/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,11 @@ def _validate_update_cache_params(
), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"

torch._check_is_size(start_pos)
if indices is None:
# Bounds checks are skipped for DYNAMIC_UNBOUND (lazy) caches where
# cache.size(1) starts at 0 and grows at runtime.
if indices is None and cache.size(1) > 0:
torch._check(start_pos < cache.size(1))
assert start_pos < cache.size(
1
), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

torch._check((start_pos + seq_len) <= cache.size(1))
assert (start_pos + seq_len) <= cache.size(
1
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"

if indices is not None:
assert (
Expand Down
8 changes: 8 additions & 0 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,14 @@ Tensor& sdpa_with_kv_cache_out(

ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");

// For DYNAMIC_UNBOUND lazy KV cache: allocate if data is null.
if (key_cache.const_data_ptr() == nullptr) {
resize_tensor(key_cache, key_cache.sizes());
}
if (value_cache.const_data_ptr() == nullptr) {
resize_tensor(value_cache, value_cache.sizes());
}

update_cache(k_projected, key_cache, start_pos, seq_len);
update_cache(v_projected, value_cache, start_pos, seq_len);

Expand Down
40 changes: 40 additions & 0 deletions extension/llm/custom_ops/op_update_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

Expand Down Expand Up @@ -207,6 +208,35 @@ Tensor& update_cache_impl(
}
} // anonymous namespace

// Grow the cache's sequence dimension (dim 1) if needed, for DYNAMIC_UNBOUND
// lazy KV caches.
//
// The cache is resized when either
//   (a) writing `value` at `start_pos` would run past the current size of
//       dim 1, or
//   (b) the cache reports a null data pointer (nothing allocated yet).
// Dims 0, 2 and 3 are carried over unchanged, and dim 1 is never shrunk.
//
// Returns false if `cache` or `value` is not 4D, or if resize_tensor() does
// not return Error::Ok; returns true otherwise, including when no resize was
// required.
static bool maybe_resize_cache(
    RuntimeContext& ctx,
    const Tensor& value,
    Tensor& cache,
    int64_t start_pos) {
  // ctx is not consulted by this helper; mark it used so builds that warn on
  // unused parameters stay clean.
  (void)ctx;

  ET_CHECK_OR_RETURN_FALSE(cache.dim() == 4, "cache must be a 4D tensor");
  ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");

  // Normalise both lengths to int64_t before comparing them.
  // NOTE(review): Tensor::size() may not return int64_t on every build
  // configuration; if so, mixing the two in std::max would not compile.
  // Confirm against the Tensor declaration.
  const int64_t current_seq = static_cast<int64_t>(cache.size(1));
  const int64_t required_seq =
      start_pos + static_cast<int64_t>(value.size(1));

  // A lazy DYNAMIC_UNBOUND cache starts with null data, so it has to go
  // through resize_tensor() even when its nominal size is already sufficient.
  const bool unallocated = cache.const_data_ptr() == nullptr;
  if (required_seq <= current_seq && !unallocated) {
    return true;
  }

  // Never shrink: keep the larger of the current and the required length.
  const int64_t new_seq =
      required_seq > current_seq ? required_seq : current_seq;
  executorch::aten::SizesType new_sizes[] = {
      static_cast<executorch::aten::SizesType>(cache.size(0)),
      static_cast<executorch::aten::SizesType>(new_seq),
      static_cast<executorch::aten::SizesType>(cache.size(2)),
      static_cast<executorch::aten::SizesType>(cache.size(3)),
  };
  auto err =
      resize_tensor(cache, {new_sizes, static_cast<size_t>(cache.dim())});
  return err == Error::Ok;
}

// Original update_cache_out function without indices parameter
Tensor& update_cache_out(
RuntimeContext& ctx,
Expand All @@ -215,6 +245,11 @@ Tensor& update_cache_out(
const int64_t start_pos,
Tensor& output) {
int64_t seq_len = value.size(1);
ET_KERNEL_CHECK(
ctx,
maybe_resize_cache(ctx, value, cache, start_pos),
InvalidArgument,
output);
ET_KERNEL_CHECK(
ctx,
validate_cache_params(value, cache, start_pos, seq_len),
Expand All @@ -233,6 +268,11 @@ Tensor& update_cache_with_indices_out(
const Tensor& indices,
Tensor& output) {
int64_t seq_len = value.size(1);
ET_KERNEL_CHECK(
ctx,
maybe_resize_cache(ctx, value, cache, start_pos),
InvalidArgument,
output);
ET_KERNEL_CHECK(
ctx,
validate_cache_params(value, cache, start_pos, seq_len, indices),
Expand Down
141 changes: 141 additions & 0 deletions extension/llm/custom_ops/test_lazy_kv_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Tests for lazy KV cache (DYNAMIC_UNBOUND) support.

Tests the update_cache custom op's ability to handle caches that start at
seq_len=0 and grow on demand, which is the foundation for pay-as-you-go
KV cache memory allocation.
"""

# pyre-unsafe

import unittest

import torch

from executorch.extension.llm.custom_ops import custom_ops # noqa


class LazyKVCacheUpdateTest(unittest.TestCase):
    """Checks ``update_cache`` writes and ``CustomKVCache`` lazy/eager sizing."""

    def setUp(self):
        # Fixed seed so the random values written into the cache are
        # reproducible from run to run.
        torch.manual_seed(42)
        self.batch_size, self.num_heads = 1, 4
        self.head_dim, self.max_seq_len = 8, 64

    def test_update_cache_grows_from_zero(self):
        """Append ten single-token values and check each lands in its slot.

        Despite the test name, the cache here is allocated with
        ``max_seq_len`` entries up front; only the write position advances.
        """
        full_shape = (
            self.batch_size,
            self.max_seq_len,
            self.num_heads,
            self.head_dim,
        )
        token_shape = (self.batch_size, 1, self.num_heads, self.head_dim)
        cache = torch.zeros(full_shape, dtype=torch.float32)

        for pos in range(10):
            token = torch.randn(token_shape, dtype=torch.float32)
            torch.ops.llama.update_cache(token, cache, pos)
            written = cache[:, pos : pos + 1, :, :]
            self.assertTrue(
                torch.allclose(written, token),
                f"Cache mismatch at position {pos}",
            )

    def test_custom_kv_cache_lazy_init(self):
        """``lazy=True`` yields k/v buffers whose dim 1 is zero."""
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomKVCache,
        )

        ceiling = 131072  # 128K ceiling
        lazy_cache = CustomKVCache(
            max_batch_size=1,
            max_context_length=ceiling,
            n_heads=4,
            head_dim=8,
            dtype=torch.float32,
            lazy=True,
        )
        k_seq = lazy_cache.k_cache.shape[1]
        self.assertEqual(k_seq, 0, "Lazy k_cache seq dim should be 0")
        v_seq = lazy_cache.v_cache.shape[1]
        self.assertEqual(v_seq, 0, "Lazy v_cache seq dim should be 0")
        # The configured ceiling is still recorded on the module.
        self.assertEqual(lazy_cache.max_context_length, ceiling)

    def test_custom_kv_cache_non_lazy_init(self):
        """``lazy=False`` yields k/v buffers whose dim 1 equals the context length."""
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomKVCache,
        )

        eager_cache = CustomKVCache(
            max_batch_size=1,
            max_context_length=64,
            n_heads=4,
            head_dim=8,
            dtype=torch.float32,
            lazy=False,
        )
        k_seq = eager_cache.k_cache.shape[1]
        self.assertEqual(k_seq, 64)
        v_seq = eager_cache.v_cache.shape[1]
        self.assertEqual(v_seq, 64)

    def test_replace_kv_cache_with_lazy(self):
        """``replace_kv_cache_with_custom_kv_cache(..., lazy=True)`` swaps in a lazy cache."""
        from executorch.examples.models.llama.attention import KVCache
        from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
            CustomKVCache,
            replace_kv_cache_with_custom_kv_cache,
        )

        class FakeModel(torch.nn.Module):
            """Minimal module holding a single ``KVCache`` child named ``kv``."""

            def __init__(self):
                super().__init__()
                # KVCache stores as [B, H, S, D]
                self.kv = KVCache(
                    max_batch_size=1,
                    max_context_length=128,
                    n_heads=4,
                    head_dim=8,
                    enable_dynamic_shape=False,
                    dtype=torch.float32,
                )

            def forward(self, x):
                return x

        holder = FakeModel()
        replace_kv_cache_with_custom_kv_cache(holder, lazy=True)
        self.assertIsInstance(holder.kv, CustomKVCache)
        # CustomKVCache stores as [B, S, H, D], lazy means seq_dim=0
        replaced_seq = holder.kv.k_cache.shape[1]
        self.assertEqual(replaced_seq, 0)


class LazyKVCacheMetaKernelTest(unittest.TestCase):
    """Calls ``update_cache`` on a zero-length cache and expects no rejection.

    NOTE(review): the tensors below are created without a device argument, so
    which registered implementation handles the call depends on dispatch.
    Confirm this actually exercises the meta kernel the test names refer to.
    """

    def _update_empty_cache(self, start_pos):
        """Run ``update_cache`` against a cache whose seq dim is 0 and return its result."""
        value = torch.randn(1, 1, 4, 8)
        # Cache with seq_len=0 (lazy)
        empty_cache = torch.zeros(1, 0, 4, 8)
        return torch.ops.llama.update_cache(value, empty_cache, start_pos)

    def test_meta_kernel_allows_start_pos_beyond_cache(self):
        """``start_pos == 0`` with an empty cache must not raise."""
        # The runtime op is expected to handle the resize.
        outcome = self._update_empty_cache(0)
        self.assertIsNotNone(outcome)

    def test_meta_kernel_allows_large_start_pos(self):
        """A ``start_pos`` well past the empty cache's size must not raise."""
        # Lazy cache (size(1)==0) skips bounds checks; the runtime op is
        # expected to handle the resize.
        outcome = self._update_empty_cache(100)
        self.assertIsNotNone(outcome)


if __name__ == "__main__":
unittest.main()
Loading
Loading