Commit 6bb6ab3

[ET Device Support] Device-aware memory planning: separate buffers per device type (#18375)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #18474
* #18473
* #18472
* __->__ #18375
* #19497
* #19496

Extends memory planning to separate device tensors from CPU tensors into distinct memory buffers. Non-CPU TensorSpecs (e.g., CUDA) are pre-assigned device-specific mem_ids before the greedy/naive algorithm runs, ensuring they get planned into independent memory buffers that never share space with CPU tensors.

Differential Revision: [D97447105](https://our.internmc.facebook.com/intern/diff/D97447105/)
1 parent 50d6d05 commit 6bb6ab3

6 files changed

Lines changed: 441 additions & 29 deletions

docs/source/compiler-memory-planning.md

Lines changed: 38 additions & 0 deletions
@@ -82,8 +82,46 @@ program = edge_program.to_executorch(
 )
 ```
 
+> **Note:** Custom pool passes that pre-assign `mem_id` are not yet compatible
+> with `enable_non_cpu_memory_planning=True`. When per-device planning is
+> enabled, device buffers are appended after the CPU buffers in the global
+> `bufsizes` array. If a custom pass has already set `mem_id` values (e.g.
+> `mem_id=2` or `mem_id=3`), those slots may collide with the device-buffer
+> slots, leading to incorrect memory layout. If both features are enabled
+> simultaneously, `apply_algo` will raise a `NotImplementedError`.
+
 Users attempting to write a custom memory planning algorithm should start by looking at [the greedy algorithm's implementation](https://github.com/pytorch/executorch/blob/main/exir/memory_planning.py#L801).
 
+## Device-Aware Memory Planning
+
+When `enable_non_cpu_memory_planning=True` is set on `ExecutorchBackendConfig`,
+the memory planning pass partitions tensor specs by their device type and runs
+the planning algorithm independently for each device. This produces separate
+memory buffers for each device (e.g. CPU vs. CUDA), ensuring that device memory
+and host memory are never mixed.
+
+```python
+program = edge_program.to_executorch(
+    exir.ExecutorchBackendConfig(
+        enable_non_cpu_memory_planning=True,
+    )
+)
+```
+
+The resulting `bufsizes` array layout depends on which devices are present:
+
+| Scenario | bufsizes | Description |
+|---|---|---|
+| CPU only | `[0, cpu_size]` | Same as legacy behavior |
+| CUDA only | `[0, cuda_size]` | Buffer 1 is CUDA, no wasted CPU slot |
+| CPU + CUDA | `[0, cpu_size, cuda_size]` | Buffer 1 is CPU, buffer 2 is CUDA |
+
+**Current limitations:**
+- Not compatible with custom pool passes that pre-assign `spec.mem_id` (see note above).
+- Submodule buffer sizes (from control-flow submodules like `cond`/`while`/`map`)
+  are applied only to the CPU partition. This is safe today because on-device
+  tensors only appear as delegate blob I/O, never inside control-flow submodules.
+
 ## Debugging Tool
 
 Please refer to [Memory Planning Inspection](memory-planning-inspection.md) for a tool to inspect the result of memory planning.
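
The `bufsizes` layouts listed in the documentation table above can be reproduced with a small standalone sketch. This is not ExecuTorch code: `merge_bufsizes` and the string device names are hypothetical stand-ins, and the sizes are made-up numbers. It only illustrates the rule that slot 0 is reserved, the CPU buffers come first, and each device's local placeholder slot is dropped when its buffers are appended.

```python
def merge_bufsizes(per_device: dict[str, list[int]]) -> list[int]:
    """Concatenate per-device buffer sizes behind the reserved slot 0.

    Hypothetical helper: each value is an algorithm-local list whose
    index 0 is a placeholder, mirroring the layout described above.
    """
    merged = [0]  # slot 0 is reserved for constants
    # CPU sorts first; the remaining devices follow in name order.
    for device in sorted(per_device, key=lambda d: (d != "cpu", d)):
        merged.extend(per_device[device][1:])  # drop the local placeholder
    if len(merged) < 2:
        merged.append(0)  # keep the legacy [0, 0] shape when nothing is planned
    return merged


cpu_only = merge_bufsizes({"cpu": [0, 256]})
cuda_only = merge_bufsizes({"cuda": [0, 1024]})
cpu_and_cuda = merge_bufsizes({"cuda": [0, 1024], "cpu": [0, 256]})

print(cpu_only)      # [0, 256]
print(cuda_only)     # [0, 1024]
print(cpu_and_cuda)  # [0, 256, 1024]
```

In the CUDA-only case buffer 1 belongs to CUDA, which matches the "no wasted CPU slot" row of the table.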

exir/capture/_config.py

Lines changed: 6 additions & 0 deletions
@@ -117,3 +117,9 @@ class ExecutorchBackendConfig:
 
     # Experimental: If set to true, we run a pass to reinplace ops in the graph.
     run_reinplace_pass: bool = False
+
+    # When True, memory planning partitions specs by device and runs the
+    # algorithm independently per device, producing separate buffers for CPU
+    # vs. accelerator memory. Default False preserves the legacy behavior
+    # where all tensors are planned into CPU memory regardless of device.
+    enable_non_cpu_memory_planning: bool = False

exir/memory_planning.py

Lines changed: 157 additions & 29 deletions
@@ -31,7 +31,7 @@
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.error import internal_assert, InternalError
 from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
-from executorch.exir.schema import TensorShapeDynamism
+from executorch.exir.schema import DeviceType, NonConstBufferDevice, TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
 from torch import fx
 from torch.export.exported_program import (
@@ -1203,6 +1203,74 @@ def _handle(
     return bufsizes
 
 
+_CPU_KEY: tuple[DeviceType, int] = (DeviceType.CPU, 0)
+
+
+def _partition_specs_by_device(
+    all_specs: set[TensorSpec],
+    enable_non_cpu_memory_planning: bool,
+) -> dict[tuple[DeviceType, int], set[TensorSpec]]:
+    """Partition specs by (device_type, device_index).
+
+    Different device indices on the same device type (e.g. CUDA:0 vs CUDA:1)
+    get separate memory buffers.
+
+    When ``enable_non_cpu_memory_planning`` is False (legacy), all specs are
+    placed into a single CPU:0 bucket regardless of their device attribute.
+    """
+    specs_by_device: dict[tuple[DeviceType, int], set[TensorSpec]] = defaultdict(set)
+    if not enable_non_cpu_memory_planning:
+        specs_by_device[_CPU_KEY] = all_specs
+        return specs_by_device
+
+    has_non_cpu_specs = False
+    has_pre_assigned_mem_id = False
+    for spec in all_specs:
+        device_key = (spec.device, spec.device_index)
+        specs_by_device[device_key].add(spec)
+        if spec.device != DeviceType.CPU:
+            has_non_cpu_specs = True
+        if spec.mem_id is not None:
+            has_pre_assigned_mem_id = True
+
+    # Custom pool passes pre-assign mem_ids (e.g. mem_id=2, 3, …) to place
+    # tensors into specific memory arenas. Per-device partitioning appends
+    # device buffers after the CPU buffers, and the remap formula
+    # global_mem_id = (local_mem_id - 1) + base_mem_id
+    # assumes the algo-local numbering starts at 1. If a custom pass has
+    # already set mem_ids > 1 on the CPU side, the device-buffer slots may
+    # collide with those custom pool slots.
+    # TODO(gasoonjia): support custom pools + per-device planning by reserving
+    # device slots after the highest custom pool id.
+    if has_non_cpu_specs and has_pre_assigned_mem_id:
+        raise NotImplementedError(
+            "enable_non_cpu_memory_planning is not yet compatible with "
+            "custom memory pool passes that pre-assign spec.mem_id. "
+            "The per-device buffer slots may collide with custom pool "
+            "mem_ids. Please disable enable_non_cpu_memory_planning or "
+            "remove the custom mem_id assignments."
+        )
+
+    return specs_by_device
+
+
+def _build_non_const_buffer_device(
+    buffer_devices: list[tuple[DeviceType, int]],
+) -> Optional[list[NonConstBufferDevice]]:
+    """Build the non-CPU buffer device list for serialization.
+
+    Returns ``None`` when all buffers are CPU (the default), so that no
+    redundant device metadata is emitted.
+    """
+    if not any(dk[0] != DeviceType.CPU for dk in buffer_devices):
+        return None
+    return [
+        NonConstBufferDevice(buffer_idx=i, device_type=dt, device_index=di)
+        for i, (dt, di) in enumerate(buffer_devices)
+        if (dt, di) != _CPU_KEY
+    ]
+
+
 def apply_algo(
     algo: Callable[..., list[int]],
     graph_module: torch.fx.GraphModule,
@@ -1211,10 +1279,19 @@ def apply_algo(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
+    enable_non_cpu_memory_planning: bool = False,
 ) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
 
+    Partitions specs by device type and device idx, and runs the memory planning
+    algorithm independently per device, then merges results into separate buffers.
+    This ensures device memory and CPU memory are never mixed.
+
+    When enable_non_cpu_memory_planning is False (default), all specs are planned
+    into a single CPU memory pool regardless of their device attribute. This
+    preserves the legacy behavior. Set to True to enable per-device partitioning.
+
     Algo implementation should handle one of two meta entries for submodules:
     1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
        `algo` should start at the offset specified by this list;
@@ -1229,49 +1306,100 @@ def apply_algo(
        `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
-    # Difficult to just filter the list of specs returned by this due to
-    # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
 
-    # Filter specs based on alloc_graph_input and alloc_graph_output
-    specs = collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=False,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-        ignore_mutable_buffers=not alloc_mutable_buffers,
+    # Collect and materialize specs into a set so we can iterate multiple
+    # times and partition by device.
+    all_specs: set[TensorSpec] = set(
+        collect_specs_from_nodes(
+            graph_module.graph.nodes,
+            graph_signature,
+            do_assertion=False,
+            ignore_graph_input=not alloc_graph_input,
+            ignore_graph_output=not alloc_graph_output,
+            ignore_mutable_buffers=not alloc_mutable_buffers,
+        )
     )
 
     # Get temporary specs for submodules to set aside space during execution
     # of submodules.
+    # NOTE: submodule_bufsizes are currently applied only to the CPU partition.
+    # This assumes all control-flow submodule tensors (cond/while/map) live in
+    # CPU memory. Today this is safe because on-device tensors only appear as
+    # delegate blob I/O, which never lives inside control-flow submodules.
+    # If device tensors ever appear in submodules, _apply_algo_to_submodules
+    # will need per-device partitioning as well.
     submodule_bufsizes = _apply_algo_to_submodules(
         algo, graph_module, alignment, graph_signature
     )
 
-    # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
-    # algos to work using `input_mem_buffer_sizes` or use
-    # `non_const_buffer_sizes` directly.
-    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
-    graph_module.input_mem_buffer_sizes = submodule_bufsizes
-
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64
 
-    # Pass the filtered specs to the algorithm
-    bufsizes: list[int] = algo(
-        alignment,
-        specs,
-        graph_module,
-        graph_signature,
-        extra_padding,
+    specs_by_device = _partition_specs_by_device(
+        all_specs, enable_non_cpu_memory_planning
     )
 
-    # pyre-ignore[6]: Incompatible parameter type [6]
-    # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
-    insert_calls_to_free(graph_module, specs)
+    # Plan each device independently
+    global_bufsizes: list[int] = [0]  # index 0 reserved for constants
+    buffer_devices: list[tuple[DeviceType, int]] = [_CPU_KEY]
 
-    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-    return bufsizes
+    # Process CPU:0 first (if present), then other devices sorted by
+    # (type.value, index) so the ordering is deterministic.
+    device_order = sorted(
+        specs_by_device.keys(),
+        key=lambda dk: (dk != _CPU_KEY, dk[0].value, dk[1]),
+    )
+
+    for device_key in device_order:
+        device_specs = specs_by_device[device_key]
+
+        # Only apply submodule pre-allocation for CPU specs; device buffers
+        # do not share memory space with CPU submodule arenas.
+        # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+        graph_module.input_mem_buffer_sizes = (
+            submodule_bufsizes if device_key == _CPU_KEY else []
+        )
+
+        # Run algorithm independently on this device's specs
+        device_bufsizes = algo(
+            alignment, device_specs, graph_module, graph_signature, extra_padding
+        )
+
+        # Calculate base mem_id in global space
+        base_mem_id = len(global_bufsizes)
+
+        # Append buffer sizes (skip index 0 which is constants placeholder)
+        global_bufsizes.extend(device_bufsizes[1:])
+
+        # Track device key for each new buffer slot
+        for _ in device_bufsizes[1:]:
+            buffer_devices.append(device_key)
+
+        # Remap spec mem_ids from algo-local to global.
+        # At this point spec.mem_id has been set by MemoryPlanningAlgorithmSuite:
+        # the suite runs each algorithm (e.g. greedy), picks the best result,
+        # and writes the winning mem_id/mem_offset/mem_obj_id back onto each
+        # spec. For specs with no pre-assigned mem_id the algorithm defaults
+        # to mem_id=1; custom-pool passes may pre-assign other values (e.g. 3).
+        # We remap from the algo-local numbering (1-based) to the global
+        # position: global_mem_id = (local_mem_id - 1) + base_mem_id.
+        for spec in device_specs:
+            if spec.mem_id is not None:
+                spec.mem_id = (spec.mem_id - 1) + base_mem_id
+
+    # Ensure backward compatibility: at least [0, 0] when no specs exist
+    if len(global_bufsizes) < 2:
+        global_bufsizes.append(0)
+        buffer_devices.append(_CPU_KEY)
+
+    # Insert free calls and build device buffer mapping
+    insert_calls_to_free(graph_module, all_specs)
+
+    non_const_buffer_device = _build_non_const_buffer_device(buffer_devices)
+    graph_module.meta["non_const_buffer_sizes"] = global_bufsizes
+    if non_const_buffer_device is not None:
+        graph_module.meta["non_const_buffer_device"] = non_const_buffer_device
+    return global_bufsizes
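
To make the remap formula `global_mem_id = (local_mem_id - 1) + base_mem_id` from the hunk above concrete, here is a minimal standalone walk-through. It is a sketch under stated assumptions, not the ExecuTorch implementation: `Spec`, `plan`, the string device types, and the buffer sizes are all hypothetical, and the planning algorithm is replaced by a stub that places every spec in local buffer 1.

```python
from dataclasses import dataclass
from typing import Optional

CPU_KEY = ("cpu", 0)  # stand-in for (DeviceType.CPU, 0)


@dataclass
class Spec:
    """Hypothetical stand-in for TensorSpec; only mem_id matters here."""
    name: str
    mem_id: Optional[int] = None


def plan(specs_by_device, sizes_by_device):
    """Merge per-device plans and remap local mem_ids to global ones."""
    global_bufsizes = [0]       # index 0 reserved for constants
    buffer_devices = [CPU_KEY]  # device owning each global buffer slot
    # CPU:0 first, then the other devices in a deterministic order.
    order = sorted(specs_by_device, key=lambda dk: (dk != CPU_KEY, dk[0], dk[1]))
    for key in order:
        local_sizes = sizes_by_device[key]
        for spec in specs_by_device[key]:
            spec.mem_id = 1  # stub: the "algorithm" uses local buffer 1
        base_mem_id = len(global_bufsizes)
        global_bufsizes.extend(local_sizes[1:])
        buffer_devices.extend([key] * len(local_sizes[1:]))
        for spec in specs_by_device[key]:
            spec.mem_id = (spec.mem_id - 1) + base_mem_id
    # Only non-CPU slots need explicit device metadata.
    non_cpu = [(i, dk) for i, dk in enumerate(buffer_devices) if dk != CPU_KEY]
    return global_bufsizes, non_cpu


x, y, z = Spec("x"), Spec("y"), Spec("z")
specs = {("cuda", 1): [z], CPU_KEY: [x], ("cuda", 0): [y]}
sizes = {CPU_KEY: [0, 64], ("cuda", 0): [0, 512], ("cuda", 1): [0, 128]}
bufsizes, non_cpu = plan(specs, sizes)

print(bufsizes)                       # [0, 64, 512, 128]
print(x.mem_id, y.mem_id, z.mem_id)   # 1 2 3
print(non_cpu)                        # [(2, ('cuda', 0)), (3, ('cuda', 1))]
```

Each device sees its own buffer as local `mem_id` 1, and the base offset taken before appending moves it to its final slot, so CUDA:0 and CUDA:1 end up in separate global buffers after the CPU one.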

exir/passes/memory_planning_pass.py

Lines changed: 3 additions & 0 deletions
@@ -153,6 +153,7 @@ def __init__(
         alloc_mutable_buffers: bool = True,
         share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
+        enable_non_cpu_memory_planning: bool = False,
     ) -> None:
         r"""
         alloc_graph_input/alloc_graph_output will have 4 different combinations
@@ -173,6 +174,7 @@ def __init__(
         self.alloc_mutable_buffers = alloc_mutable_buffers
         self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning
         self.state = _MemoryPlanningState()
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
@@ -250,6 +252,7 @@ def run(
             # If mutable buffers are shared, then do not allocate them in the
             # main memory planning algo; they are allocated in run_multimethod.
             self.alloc_mutable_buffers and not self.share_mutable_buffers,
+            self.enable_non_cpu_memory_planning,
         )
 
         if self.share_mutable_buffers and graph_signature is not None:

exir/program/_program.py

Lines changed: 6 additions & 0 deletions
@@ -1798,6 +1798,12 @@ def to_executorch( # noqa (FLAKE8) C901
             )
         else:
             memory_planning_pass = config.memory_planning_pass
+        # Propagate enable_non_cpu_memory_planning from the top-level config
+        # to the pass instance so that device-aware partitioning is applied.
+        if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"):
+            memory_planning_pass.enable_non_cpu_memory_planning = (
+                config.enable_non_cpu_memory_planning
+            )
         # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
         if hasattr(memory_planning_pass, "run"):
             new_gm_res = memory_planning_pass.run(new_gm, new_signature)
