Skip to content

Commit 944b033

Browse files
psiddh and claude committed
DYNAMIC_UNBOUND support for portable runtime: lazy KV cache allocation
Enable DYNAMIC_UNBOUND tensors in the portable runtime, allowing KV cache buffers to be dynamically managed rather than statically memory-planned. This is the architectural foundation for pay-as-you-go memory allocation in ExecuTorch LLM inference.

Core changes:
- DynamicAllocator interface with allocate/reallocate/free
- PalDynamicAllocator default impl (PAL-backed, 2x growth policy)
- TrackingDynamicAllocator for memory stats observability
- MemoryManager gains 4th slot for DynamicAllocator (backward compatible)
- TensorImpl gains dynamic_allocator_ and capacity_bytes_ fields
- TensorImpl::internal_resize_contiguous handles DYNAMIC_UNBOUND resize
- tensor_parser_portable.cpp: remove DYNAMIC_UNBOUND rejection, wire up allocator at load time for tensors with no memory-planned data
- method.cpp: FreeCall frees dynamic memory; destructor cleans up all
- Module API auto-creates PalDynamicAllocator (DYNAMIC_UNBOUND just works)

Export changes:
- MarkDynamicUnboundPass marks KV cache buffers as DYNAMIC_UNBOUND
- --lazy_kv_cache flag for Llama export

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 02bad9d commit 944b033

17 files changed

Lines changed: 522 additions & 12 deletions

examples/models/llama/export_llama_lib.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,15 @@ def build_args_parser() -> argparse.ArgumentParser:
418418
help="maximum length of context for model to remember",
419419
)
420420

421+
parser.add_argument(
422+
"--lazy_kv_cache",
423+
action="store_true",
424+
default=False,
425+
help="Mark KV cache buffers as DYNAMIC_UNBOUND so they are allocated "
426+
"lazily at runtime instead of at load time. Reduces initial memory "
427+
"usage when max_context_length is large.",
428+
)
429+
421430
parser.add_argument(
422431
"--local_global_attention",
423432
type=parse_list_of_ints,
@@ -1362,6 +1371,13 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
13621371
if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
13631372
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
13641373

1374+
if llm_config.export.lazy_kv_cache:
1375+
from executorch.exir.passes.mark_dynamic_unbound_pass import (
1376+
MarkDynamicUnboundPass,
1377+
)
1378+
1379+
additional_passes.append(MarkDynamicUnboundPass())
1380+
13651381
# export_to_edge
13661382
builder_manager = _prepare_for_llama_export(llm_config)
13671383
if (
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List, Optional
8+
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
class MarkDynamicUnboundPass(ExportPass):
13+
"""
14+
Marks matching placeholder nodes with ``et_dynamic_unbound`` metadata.
15+
16+
After ``SpecPropPass`` creates ``TensorSpec`` for each placeholder,
17+
``update_placeholder_tensor_specs`` reads this flag and sets the spec's
18+
``shape_dynamism`` to ``DYNAMIC_UNBOUND``. The memory planner then skips
19+
those tensors, and the runtime allocates their memory lazily via
20+
``DynamicAllocator``.
21+
22+
Typical usage: mark KV cache buffers so they start unallocated and grow
23+
on demand, avoiding the full upfront memory cost of max_context_length.
24+
"""
25+
26+
def __init__(
27+
self,
28+
name_patterns: Optional[List[str]] = None,
29+
) -> None:
30+
super().__init__()
31+
self.name_patterns = name_patterns or ["k_cache", "v_cache"]
32+
33+
def placeholder(self, name: str, arg, meta):
34+
if any(pattern in name for pattern in self.name_patterns):
35+
meta["et_dynamic_unbound"] = True
36+
return super().placeholder(name, arg, meta)

exir/passes/spec_prop_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from executorch.exir.delegate import executorch_call_delegate
1414
from executorch.exir.pass_base import ExportPass, ProxyValue
15+
from executorch.exir.schema import TensorShapeDynamism
1516
from executorch.exir.tensor import TensorSpec
1617
from torch.export.exported_program import ExportGraphSignature
1718
from torch.fx.node import Node
@@ -121,3 +122,7 @@ def update_placeholder_tensor_specs(
121122
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
122123
):
123124
spec.const = True
125+
if isinstance(spec, TensorSpec) and node.meta.get(
126+
"et_dynamic_unbound", False
127+
):
128+
spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_UNBOUND

extension/llm/export/config/llm_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ class ExportConfig:
256256
export_only: bool = False
257257
foundation_weights_file: Optional[str] = None
258258
lora_weights_file: Optional[str] = None
259+
lazy_kv_cache: bool = False
259260

260261
def __post_init__(self):
261262
if self.max_context_length < self.max_seq_length:
@@ -695,6 +696,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
695696
llm_config.export.foundation_weights_file = args.foundation_weights_file
696697
if hasattr(args, "lora_weights_file"):
697698
llm_config.export.lora_weights_file = args.lora_weights_file
699+
if hasattr(args, "lazy_kv_cache"):
700+
llm_config.export.lazy_kv_cache = args.lazy_kv_cache
698701

699702
# QuantizationConfig
700703
if hasattr(args, "quantization_mode"):

extension/module/module.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1414
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
1515
#include <executorch/extension/named_data_map/merged_data_map.h>
16+
#include <executorch/runtime/executor/pal_dynamic_allocator.h>
1617
#include <executorch/runtime/platform/runtime.h>
1718

1819
namespace executorch {
@@ -389,8 +390,13 @@ runtime::Error Module::load_method(
389390
planned_memory = method_holder.planned_memory->planned_memory.get();
390391
}
391392

393+
method_holder.dynamic_allocator =
394+
std::make_unique<runtime::PalDynamicAllocator>();
392395
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
393-
memory_allocator_.get(), planned_memory, temp_allocator_.get());
396+
memory_allocator_.get(),
397+
planned_memory,
398+
temp_allocator_.get(),
399+
method_holder.dynamic_allocator.get());
394400
auto res_method = program_->load_method(
395401
method_name.c_str(),
396402
method_holder.memory_manager.get(),

extension/module/module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <unordered_set>
1515
#include <vector>
1616

17+
#include <executorch/runtime/executor/dynamic_allocator.h>
1718
#include <executorch/runtime/executor/program.h>
1819

1920
#ifdef USE_ATEN_LIB
@@ -694,6 +695,7 @@ class Module {
694695

695696
struct MethodHolder {
696697
std::unique_ptr<PlannedMemory> planned_memory;
698+
std::unique_ptr<runtime::DynamicAllocator> dynamic_allocator;
697699
std::unique_ptr<runtime::MemoryManager> memory_manager;
698700
std::unique_ptr<Method> method;
699701
};

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def define_common_targets():
2525
"//executorch/extension/data_loader:mmap_data_loader",
2626
"//executorch/extension/flat_tensor:flat_tensor_data_map" + aten_suffix,
2727
"//executorch/extension/named_data_map:merged_data_map" + aten_suffix,
28+
"//executorch/runtime/executor:pal_dynamic_allocator",
2829
],
2930
exported_deps = [
3031
"//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,

runtime/core/portable_type/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def define_common_targets():
4141
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
4242
"//executorch/runtime/core/exec_aten/util:tensor_shape_to_c_string",
4343
"//executorch/runtime/core:tag",
44+
"//executorch/runtime/executor:dynamic_allocator",
4445
],
4546
)
4647

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,59 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef<SizesType> new_sizes) {
113113
}
114114

115115
break;
116-
case TensorShapeDynamism::DYNAMIC_BOUND:
117-
// TODO(T175194371): Unbounded dynamic tensor resizing is not yet
118-
// supported: treat them as upper-bounded.
119116
case TensorShapeDynamism::DYNAMIC_UNBOUND: {
120117
const auto new_numel = compute_numel(new_sizes.data(), dim_);
121118

119+
ET_CHECK_OR_RETURN_ERROR(
120+
static_cast<size_t>(new_numel) <= numel_bound_,
121+
NotSupported,
122+
"Attempted to resize a dynamic unbound tensor beyond its ceiling of %zu elements to %zu elements.",
123+
numel_bound_,
124+
new_numel);
125+
126+
const size_t needed_bytes =
127+
static_cast<size_t>(new_numel) * elementSize(type_);
128+
// If capacity_bytes_ is 0 but data_ is non-null, the buffer is
129+
// externally managed (e.g., stack-allocated in tests). Use the
130+
// original numel bound as the effective capacity.
131+
const size_t effective_capacity = capacity_bytes_ > 0
132+
? capacity_bytes_
133+
: static_cast<size_t>(numel_bound_) * elementSize(type_);
134+
if (needed_bytes > effective_capacity) {
135+
ET_CHECK_OR_RETURN_ERROR(
136+
dynamic_allocator_ != nullptr,
137+
NotSupported,
138+
"DYNAMIC_UNBOUND tensor needs reallocation but has no DynamicAllocator");
139+
size_t actual_size = 0;
140+
void* new_data = dynamic_allocator_->reallocate(
141+
data_,
142+
capacity_bytes_,
143+
needed_bytes,
144+
alignof(std::max_align_t),
145+
&actual_size);
146+
ET_CHECK_OR_RETURN_ERROR(
147+
new_data != nullptr,
148+
MemoryAllocationFailed,
149+
"Failed to reallocate DYNAMIC_UNBOUND tensor to %zu bytes",
150+
needed_bytes);
151+
data_ = new_data;
152+
capacity_bytes_ = actual_size;
153+
}
154+
155+
if (strides_ && dim_order_) {
156+
auto error =
157+
dim_order_to_stride(new_sizes.data(), dim_order_, dim_, strides_);
158+
if (error != Error::Ok) {
159+
return error;
160+
}
161+
}
162+
numel_ = new_numel;
163+
std::copy(new_sizes.begin(), new_sizes.end(), sizes_);
164+
} break;
165+
166+
case TensorShapeDynamism::DYNAMIC_BOUND: {
167+
const auto new_numel = compute_numel(new_sizes.data(), dim_);
168+
122169
ET_CHECK_OR_RETURN_ERROR(
123170
static_cast<size_t>(new_numel) <= numel_bound_,
124171
NotSupported,

runtime/core/portable_type/tensor_impl.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <executorch/runtime/core/error.h>
1313
#include <executorch/runtime/core/portable_type/scalar_type.h>
1414
#include <executorch/runtime/core/tensor_shape_dynamism.h>
15+
#include <executorch/runtime/executor/dynamic_allocator.h>
1516

1617
// Forward declaration of a helper that provides access to internal resizing
1718
// methods of TensorImpl. Real definition is in
@@ -203,6 +204,26 @@ class TensorImpl {
203204
data_ = ptr;
204205
}
205206

207+
/// Returns the dynamic allocator for DYNAMIC_UNBOUND tensors, or nullptr.
208+
DynamicAllocator* dynamic_allocator() const {
209+
return dynamic_allocator_;
210+
}
211+
212+
/// Sets the dynamic allocator for lazy allocation.
213+
void set_dynamic_allocator(DynamicAllocator* allocator) {
214+
dynamic_allocator_ = allocator;
215+
}
216+
217+
/// Returns the capacity in bytes of the current dynamic allocation.
218+
size_t capacity_bytes() const {
219+
return capacity_bytes_;
220+
}
221+
222+
/// Sets the capacity in bytes of the current dynamic allocation.
223+
void set_capacity_bytes(size_t capacity) {
224+
capacity_bytes_ = capacity;
225+
}
226+
206227
/*
207228
* DEPRECATED: Use torch::executor::resize_tensor() or
208229
* torch::executor::resize_tensor_impl().
@@ -261,6 +282,13 @@ class TensorImpl {
261282

262283
/// Specifies the mutability of the shape of the tensor.
263284
const TensorShapeDynamism shape_dynamism_;
285+
286+
/// Allocator for DYNAMIC_UNBOUND tensors. nullptr for other dynamism types.
287+
DynamicAllocator* dynamic_allocator_ = nullptr;
288+
289+
/// Capacity in bytes of the buffer pointed to by data_, when managed by
290+
/// dynamic_allocator_. 0 means no allocation yet.
291+
size_t capacity_bytes_ = 0;
264292
};
265293

266294
/**

0 commit comments

Comments
 (0)