
Commit b2022bc

[ET Device Support] Device-aware memory planning: separate buffers per device type
Pull Request resolved: #18375

Extends memory planning to separate device tensors from CPU tensors into distinct memory buffers. Non-CPU TensorSpecs (e.g., CUDA) are pre-assigned device-specific mem_ids before the greedy/naive algorithm runs, ensuring they get planned into independent memory buffers that never share space with CPU tensors.

ghstack-source-id: 357060891
@exported-using-ghexport

Differential Revision: [D97447105](https://our.internmc.facebook.com/intern/diff/D97447105/)
1 parent 1e9fbed commit b2022bc

6 files changed

Lines changed: 417 additions & 29 deletions


docs/source/compiler-memory-planning.md

Lines changed: 38 additions & 0 deletions
@@ -82,8 +82,46 @@ program = edge_program.to_executorch(
 )
 ```

+> **Note:** Custom pool passes that pre-assign `mem_id` are not yet compatible
+> with `enable_non_cpu_memory_planning=True`. When per-device planning is
+> enabled, device buffers are appended after the CPU buffers in the global
+> `bufsizes` array. If a custom pass has already set `mem_id` values (e.g.
+> `mem_id=2` or `mem_id=3`), those slots may collide with the device-buffer
+> slots, leading to incorrect memory layout. If both features are enabled
+> simultaneously, `apply_algo` will raise a `NotImplementedError`.
+
 Users attempting to write a custom memory planning algorithm should start by looking at [the greedy algorithm's implementation](https://github.com/pytorch/executorch/blob/d62c41ca86435e5316e7ed292b6d68aff27a2fb7/exir/memory_planning.py#L459C1-L459C12).

+## Device-Aware Memory Planning
+
+When `enable_non_cpu_memory_planning=True` is set on `ExecutorchBackendConfig`,
+the memory planning pass partitions tensor specs by their device type and runs
+the planning algorithm independently for each device. This produces separate
+memory buffers for each device (e.g. CPU vs. CUDA), ensuring that device memory
+and host memory are never mixed.
+
+```python
+program = edge_program.to_executorch(
+    exir.ExecutorchBackendConfig(
+        enable_non_cpu_memory_planning=True,
+    )
+)
+```
+
+The resulting `bufsizes` array layout depends on which devices are present:
+
+| Scenario | bufsizes | Description |
+|---|---|---|
+| CPU only | `[0, cpu_size]` | Same as legacy behavior |
+| CUDA only | `[0, cuda_size]` | Buffer 1 is CUDA, no wasted CPU slot |
+| CPU + CUDA | `[0, cpu_size, cuda_size]` | Buffer 1 is CPU, buffer 2 is CUDA |
+
+**Current limitations:**
+- Not compatible with custom pool passes that pre-assign `spec.mem_id` (see note above).
+- Submodule buffer sizes (from control-flow submodules like `cond`/`while`/`map`)
+  are applied only to the CPU partition. This is safe today because on-device
+  tensors only appear as delegate blob I/O, never inside control-flow submodules.
+
 ## Debugging Tool

 Please refer to [Memory Planning Inspection](memory-planning-inspection.md) for a tool to inspect the result of memory planning.
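
To make the `bufsizes` table above concrete, here is a small standalone sketch in plain Python. It is not ExecuTorch code: `merge_buffers` and its `plan_sizes` input are hypothetical illustrations of the layout rule described in the docs (index 0 reserved for constants, CPU first, device buffers appended after, device metadata recorded only for non-CPU slots).

```python
# Hypothetical illustration of the merged bufsizes layout described in the
# docs table above. `merge_buffers` and `plan_sizes` are not ExecuTorch APIs.

def merge_buffers(plan_sizes: dict[str, int]) -> tuple[list[int], dict[int, str]]:
    bufsizes = [0]                      # index 0 is reserved for constants
    device_of_slot: dict[int, str] = {} # only non-CPU slots get metadata
    # CPU first (if present), then remaining devices in a fixed order
    order = sorted(plan_sizes, key=lambda d: (d != "cpu", d))
    for device in order:
        bufsizes.append(plan_sizes[device])
        if device != "cpu":
            device_of_slot[len(bufsizes) - 1] = device
    if len(bufsizes) < 2:               # keep the legacy minimum of [0, 0]
        bufsizes.append(0)
    return bufsizes, device_of_slot

print(merge_buffers({"cpu": 1024}))                  # ([0, 1024], {})
print(merge_buffers({"cuda:0": 4096}))               # ([0, 4096], {1: 'cuda:0'})
print(merge_buffers({"cpu": 1024, "cuda:0": 4096}))  # ([0, 1024, 4096], {2: 'cuda:0'})
```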

exir/capture/_config.py

Lines changed: 6 additions & 0 deletions
@@ -117,3 +117,9 @@ class ExecutorchBackendConfig:

     # Experimental: If set to true, we run a pass to reinplace ops in the graph.
     run_reinplace_pass: bool = False
+
+    # When True, memory planning partitions specs by device and runs the
+    # algorithm independently per device, producing separate buffers for CPU
+    # vs. accelerator memory. Default False preserves the legacy behavior
+    # where all tensors are planned into CPU memory regardless of device.
+    enable_non_cpu_memory_planning: bool = False

exir/memory_planning.py

Lines changed: 133 additions & 29 deletions
@@ -31,7 +31,7 @@
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.error import internal_assert, InternalError
 from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
-from executorch.exir.schema import TensorShapeDynamism
+from executorch.exir.schema import DeviceType, NonConstBufferDevice, TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
 from torch import fx
 from torch.export.exported_program import (
@@ -1211,10 +1211,19 @@ def apply_algo(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
+    enable_non_cpu_memory_planning: bool = False,
 ) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.

+    Partitions specs by device type and device index, runs the memory planning
+    algorithm independently per device, then merges the results into separate
+    buffers. This ensures device memory and CPU memory are never mixed.
+
+    When enable_non_cpu_memory_planning is False (default), all specs are planned
+    into a single CPU memory pool regardless of their device attribute, which
+    preserves the legacy behavior. Set it to True to enable per-device partitioning.
+
     Algo implementation should handle one of two meta entries for submodules:
     1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
@@ -1229,49 +1238,144 @@
        `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
-    # Difficult to just filter the list of specs returned by this due to
-    # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)

-    # Filter specs based on alloc_graph_input and alloc_graph_output
-    specs = collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=False,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-        ignore_mutable_buffers=not alloc_mutable_buffers,
+    # Collect and materialize specs into a set so we can iterate multiple
+    # times and partition by device.
+    all_specs: set[TensorSpec] = set(
+        collect_specs_from_nodes(
+            graph_module.graph.nodes,
+            graph_signature,
+            do_assertion=False,
+            ignore_graph_input=not alloc_graph_input,
+            ignore_graph_output=not alloc_graph_output,
+            ignore_mutable_buffers=not alloc_mutable_buffers,
+        )
     )

     # Get temporary specs for submodules to set aside space during execution
     # of submodules.
+    # NOTE: submodule_bufsizes are currently applied only to the CPU partition.
+    # This assumes all control-flow submodule tensors (cond/while/map) live in
+    # CPU memory. Today this is safe because on-device tensors only appear as
+    # delegate blob I/O, which never lives inside control-flow submodules.
+    # If device tensors ever appear in submodules, _apply_algo_to_submodules
+    # will need per-device partitioning as well.
     submodule_bufsizes = _apply_algo_to_submodules(
         algo, graph_module, alignment, graph_signature
     )

-    # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
-    # algos to work using `input_mem_buffer_sizes` or use
-    # `non_const_buffer_sizes` directly.
-    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
-    graph_module.input_mem_buffer_sizes = submodule_bufsizes
-
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64

-    # Pass the filtered specs to the algorithm
-    bufsizes: list[int] = algo(
-        alignment,
-        specs,
-        graph_module,
-        graph_signature,
-        extra_padding,
+    # 1. Partition specs by (device_type, device_index).
+    # Different device indices on the same device type (e.g. CUDA:0 vs CUDA:1)
+    # get separate memory buffers.
+    _CPU_KEY: tuple[DeviceType, int] = (DeviceType.CPU, 0)
+    specs_by_device: dict[tuple[DeviceType, int], set[TensorSpec]] = defaultdict(set)
+    if enable_non_cpu_memory_planning:
+        has_non_cpu_specs = False
+        has_pre_assigned_mem_id = False
+        for spec in all_specs:
+            device_key = (spec.device, spec.device_index)
+            specs_by_device[device_key].add(spec)
+            if spec.device != DeviceType.CPU:
+                has_non_cpu_specs = True
+            if spec.mem_id is not None:
+                has_pre_assigned_mem_id = True
+
+        # Custom pool passes pre-assign mem_ids (e.g. mem_id=2, 3, …) to place
+        # tensors into specific memory arenas. Per-device partitioning appends
+        # device buffers after the CPU buffers, and the remap formula
+        #     global_mem_id = (local_mem_id - 1) + base_mem_id
+        # assumes the algo-local numbering starts at 1. If a custom pass has
+        # already set mem_ids > 1 on the CPU side, the device-buffer slots may
+        # collide with those custom pool slots.
+        # TODO(gasoonjia): support custom pools + per-device planning by reserving
+        # device slots after the highest custom pool id.
+        if has_non_cpu_specs and has_pre_assigned_mem_id:
+            raise NotImplementedError(
+                "enable_non_cpu_memory_planning is not yet compatible with "
+                "custom memory pool passes that pre-assign spec.mem_id. "
+                "The per-device buffer slots may collide with custom pool "
+                "mem_ids. Please disable enable_non_cpu_memory_planning or "
+                "remove the custom mem_id assignments."
+            )
+    else:
+        # Legacy behavior: all specs planned into CPU memory regardless of device
+        specs_by_device[_CPU_KEY] = all_specs
+
+    # 2. Plan each device independently
+    global_bufsizes: list[int] = [0]  # index 0 reserved for constants
+    # Track (device_type, device_index) for each buffer slot
+    buffer_devices: list[tuple[DeviceType, int]] = [_CPU_KEY]
+
+    # Process CPU:0 first (if present), then other devices sorted by
+    # (type.value, index) so the ordering is deterministic.
+    device_order = sorted(
+        specs_by_device.keys(),
+        key=lambda dk: (dk != _CPU_KEY, dk[0].value, dk[1]),
     )

-    # pyre-ignore[6]: Incompatible parameter type [6]
-    # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
-    insert_calls_to_free(graph_module, specs)
+    for device_key in device_order:
+        device_specs = specs_by_device[device_key]

-    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-    return bufsizes
+        # Only apply submodule pre-allocation for CPU specs; device buffers
+        # do not share memory space with CPU submodule arenas.
+        # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+        graph_module.input_mem_buffer_sizes = (
+            submodule_bufsizes if device_key == _CPU_KEY else []
+        )
+
+        # Run algorithm independently on this device's specs
+        device_bufsizes = algo(
+            alignment, device_specs, graph_module, graph_signature, extra_padding
+        )
+
+        # Calculate base mem_id in global space
+        base_mem_id = len(global_bufsizes)
+
+        # Append buffer sizes (skip index 0 which is constants placeholder)
+        global_bufsizes.extend(device_bufsizes[1:])
+
+        # Track device key for each new buffer slot
+        for _ in device_bufsizes[1:]:
+            buffer_devices.append(device_key)
+
+        # Remap spec mem_ids from algo-local to global.
+        # At this point spec.mem_id has been set by MemoryPlanningAlgorithmSuite:
+        # the suite runs each algorithm (e.g. greedy), picks the best result,
+        # and writes the winning mem_id/mem_offset/mem_obj_id back onto each
+        # spec. For specs with no pre-assigned mem_id the algorithm defaults
+        # to mem_id=1; custom-pool passes may pre-assign other values (e.g. 3).
+        # We remap from the algo-local numbering (1-based) to the global
+        # position: global_mem_id = (local_mem_id - 1) + base_mem_id.
+        for spec in device_specs:
+            if spec.mem_id is not None:
+                spec.mem_id = (spec.mem_id - 1) + base_mem_id
+
+    # Ensure backward compatibility: at least [0, 0] when no specs exist
+    if len(global_bufsizes) < 2:
+        global_bufsizes.append(0)
+        buffer_devices.append(_CPU_KEY)
+
+    # 3. Insert free calls and build device buffer mapping
+    insert_calls_to_free(graph_module, all_specs)
+
+    # Only record non-CPU buffer entries. CPU buffers are the default and
+    # do not need explicit device metadata in the serialized program.
+    non_const_buffer_device: Optional[list[NonConstBufferDevice]] = None
+    has_device_buffers = any(dk[0] != DeviceType.CPU for dk in buffer_devices)
+    if has_device_buffers:
+        non_const_buffer_device = [
+            NonConstBufferDevice(buffer_idx=i, device_type=dt, device_index=di)
+            for i, (dt, di) in enumerate(buffer_devices)
+            if (dt, di) != _CPU_KEY
+        ]
+
+    graph_module.meta["non_const_buffer_sizes"] = global_bufsizes
+    if non_const_buffer_device is not None:
+        graph_module.meta["non_const_buffer_device"] = non_const_buffer_device
+    return global_bufsizes
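
The core bookkeeping in the new `apply_algo` body is the merge step: each per-device plan numbers its arenas starting at mem_id 1, and those local ids are shifted into the global buffer numbering with `global_mem_id = (local_mem_id - 1) + base_mem_id`. The following is a minimal self-contained sketch of just that remap, under the same 1-based assumption stated in the comments above; `Plan` and `remap_plans` are illustrative names, not ExecuTorch code (the real implementation mutates `TensorSpec.mem_id` in place).

```python
# Hypothetical sketch of the per-device merge + mem_id remap described above.
from dataclasses import dataclass

@dataclass
class Plan:
    # bufsizes as returned by one algorithm run: index 0 is the constants
    # placeholder, arenas 1..N hold this device's planned buffers.
    bufsizes: list[int]
    # algorithm-local mem_id (1-based) for each tensor in this plan
    local_mem_ids: list[int]

def remap_plans(plans: list[Plan]) -> tuple[list[int], list[list[int]]]:
    global_bufsizes = [0]            # slot 0 reserved for constants
    global_mem_ids: list[list[int]] = []
    for plan in plans:
        base_mem_id = len(global_bufsizes)
        global_bufsizes.extend(plan.bufsizes[1:])   # skip the constants slot
        # global_mem_id = (local_mem_id - 1) + base_mem_id
        global_mem_ids.append([(m - 1) + base_mem_id for m in plan.local_mem_ids])
    return global_bufsizes, global_mem_ids

cpu = Plan(bufsizes=[0, 1024], local_mem_ids=[1, 1])
cuda = Plan(bufsizes=[0, 4096], local_mem_ids=[1])
print(remap_plans([cpu, cuda]))
# ([0, 1024, 4096], [[1, 1], [2]]) -> CPU tensors stay in buffer 1, CUDA
# tensors land in buffer 2, matching the CPU + CUDA row of the docs table.
```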

exir/passes/memory_planning_pass.py

Lines changed: 3 additions & 0 deletions
@@ -153,6 +153,7 @@ def __init__(
         alloc_mutable_buffers: bool = True,
         share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
+        enable_non_cpu_memory_planning: bool = False,
     ) -> None:
         r"""
         alloc_graph_input/alloc_graph_output will have 4 different combinations
@@ -173,6 +174,7 @@ def __init__(
         self.alloc_mutable_buffers = alloc_mutable_buffers
         self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning
         self.state = _MemoryPlanningState()

     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
@@ -250,6 +252,7 @@ def run(
             # If mutable buffers are shared, then do not allocate them in the
             # main memory planning algo; they are allocated in run_multimethod.
             self.alloc_mutable_buffers and not self.share_mutable_buffers,
+            self.enable_non_cpu_memory_planning,
         )

         if self.share_mutable_buffers and graph_signature is not None:
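
For callers that build the pass directly, the new keyword is accepted at construction time alongside the existing allocation options. A usage sketch under the assumption that the import path matches this file's location and the other arguments keep their existing meanings:

```python
# Hypothetical usage: construct the pass with device-aware planning enabled.
# Only the enable_non_cpu_memory_planning keyword is new in this commit; the
# other arguments are the pass's existing options.
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

planning_pass = MemoryPlanningPass(
    alloc_graph_input=True,
    alloc_graph_output=True,
    enable_non_cpu_memory_planning=True,
)
```

Note that when such a pass is handed to `to_executorch` via `ExecutorchBackendConfig.memory_planning_pass`, the change in `exir/program/_program.py` below copies the top-level `config.enable_non_cpu_memory_planning` onto the pass, so in that path the config-level flag is the one that takes effect.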

exir/program/_program.py

Lines changed: 6 additions & 0 deletions
@@ -1788,6 +1788,12 @@ def to_executorch( # noqa (FLAKE8) C901
             )
         else:
             memory_planning_pass = config.memory_planning_pass
+            # Propagate enable_non_cpu_memory_planning from the top-level config
+            # to the pass instance so that device-aware partitioning is applied.
+            if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"):
+                memory_planning_pass.enable_non_cpu_memory_planning = (
+                    config.enable_non_cpu_memory_planning
+                )
         # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
         if hasattr(memory_planning_pass, "run"):
             new_gm_res = memory_planning_pass.run(new_gm, new_signature)
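
A sketch of that propagation from the caller's side, assuming an existing edge program produced by `to_edge` and otherwise default config values; `lower_with_device_planning` is a hypothetical wrapper, not part of the API:

```python
# Hypothetical end-to-end sketch. Because to_executorch copies
# config.enable_non_cpu_memory_planning onto any supplied pass that exposes
# that attribute, setting the flag on the top-level config is what takes
# effect even when a custom MemoryPlanningPass instance is provided.
from executorch import exir
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

def lower_with_device_planning(edge_program):  # edge_program: result of to_edge()
    custom_pass = MemoryPlanningPass(alloc_graph_input=False)
    program = edge_program.to_executorch(
        exir.ExecutorchBackendConfig(
            memory_planning_pass=custom_pass,
            enable_non_cpu_memory_planning=True,
        )
    )
    # custom_pass.enable_non_cpu_memory_planning is now True; it was set by
    # the hasattr propagation added in this hunk.
    return program
```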
