Commit 459cf12

[ET Device Support] Device-aware memory planning: separate buffers per device type

Extends memory planning to separate device tensors from CPU tensors into distinct memory buffers. Non-CPU TensorSpecs (e.g., CUDA) are pre-assigned device-specific mem_ids before the greedy/naive algorithm runs, ensuring they get planned into independent memory buffers that never share space with CPU tensors.

Differential Revision: [D97447105](https://our.internmc.facebook.com/intern/diff/D97447105/)
ghstack-source-id: 355133801
Pull Request resolved: #18375
Parent: 92a1850

5 files changed: 273 additions & 29 deletions


exir/capture/_config.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -115,5 +115,11 @@ class ExecutorchBackendConfig:
     # If set to true, we run quant fusion and constant propagation passes
     do_quant_fusion_and_const_prop: bool = False

-    # Experimental: If set to true, we run a pass to reinplace ops in the graph.
+    # If set to true, we run a pass to reinplace ops in the graph.
     run_reinplace_pass: bool = False
+
+    # When True, memory planning partitions specs by device and runs the
+    # algorithm independently per device, producing separate buffers for CPU
+    # vs. accelerator memory. Default False preserves the legacy behavior
+    # where all tensors are planned into CPU memory regardless of device.
+    enable_non_cpu_memory_planning: bool = False
```
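
For orientation, here is a minimal sketch of how an exporter could opt in to the new flag. `MyModel` and its input are hypothetical placeholders; the rest follows the existing public `to_edge`/`to_executorch` flow.

```python
# Hypothetical usage sketch; `MyModel` and its input are placeholders,
# not part of this commit.
import torch

from executorch.exir import ExecutorchBackendConfig, to_edge


class MyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


edge = to_edge(torch.export.export(MyModel(), (torch.randn(4),)))
# Memory planning runs inside to_executorch; the new flag turns on
# per-device buffer partitioning.
et_program = edge.to_executorch(
    ExecutorchBackendConfig(enable_non_cpu_memory_planning=True)
)
```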

exir/memory_planning.py

Lines changed: 88 additions & 28 deletions
```diff
@@ -28,6 +28,7 @@
 import torch
 from executorch.exir import memory
 from executorch.exir.control_flow import while_loop as exir_while
+from executorch.exir.schema import DeviceType, NonConstBufferDevice
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.error import internal_assert, InternalError
 from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
@@ -1211,10 +1212,19 @@ def apply_algo(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
+    enable_non_cpu_memory_planning: bool = False,
 ) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.

+    Partitions specs by device type and device idx, and runs the memory planning
+    algorithm independently per device, then merges results into separate buffers.
+    This ensures device memory and CPU memory are never mixed.
+
+    When enable_non_cpu_memory_planning is False (default), all specs are planned
+    into a single CPU memory pool regardless of their device attribute. This
+    preserves the legacy behavior. Set to True to enable per-device partitioning.
+
     Algo implementation should handle one of two meta entries for submodules:
     1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
        `algo` should start at the offset specified by this list;
@@ -1229,18 +1239,19 @@ def apply_algo(
        `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
-    # Difficult to just filter the list of specs returned by this due to
-    # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)

-    # Filter specs based on alloc_graph_input and alloc_graph_output
-    specs = collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=False,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-        ignore_mutable_buffers=not alloc_mutable_buffers,
+    # Collect and materialize specs into a set so we can iterate multiple
+    # times and partition by device.
+    all_specs: set[TensorSpec] = set(
+        collect_specs_from_nodes(
+            graph_module.graph.nodes,
+            graph_signature,
+            do_assertion=False,
+            ignore_graph_input=not alloc_graph_input,
+            ignore_graph_output=not alloc_graph_output,
+            ignore_mutable_buffers=not alloc_mutable_buffers,
+        )
     )

     # Get temporary specs for submodules to set aside space during execution
@@ -1249,29 +1260,78 @@ def apply_algo(
         algo, graph_module, alignment, graph_signature
     )

-    # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
-    # algos to work using `input_mem_buffer_sizes` or use
-    # `non_const_buffer_sizes` directly.
-    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
-    graph_module.input_mem_buffer_sizes = submodule_bufsizes
-
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64

-    # Pass the filtered specs to the algorithm
-    bufsizes: list[int] = algo(
-        alignment,
-        specs,
-        graph_module,
-        graph_signature,
-        extra_padding,
+    # 1. Partition specs by device
+    specs_by_device: dict[DeviceType, set[TensorSpec]] = defaultdict(set)
+    if enable_non_cpu_memory_planning:
+        for spec in all_specs:
+            specs_by_device[spec.device].add(spec)
+    else:
+        # Legacy behavior: all specs planned into CPU memory regardless of device
+        specs_by_device[DeviceType.CPU] = all_specs
+
+    # 2. Plan each device independently
+    global_bufsizes: list[int] = [0]  # index 0 reserved for constants
+    buffer_device_types: list[DeviceType] = [DeviceType.CPU]
+
+    # Process CPU first (if present), then other devices sorted by enum value
+    device_order = sorted(
+        specs_by_device.keys(),
+        key=lambda d: (d != DeviceType.CPU, d.value),
     )

-    # pyre-ignore[6]: Incompatible parameter type [6]
-    # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
-    insert_calls_to_free(graph_module, specs)
+    for device_type in device_order:
+        device_specs = specs_by_device[device_type]

-    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-    return bufsizes
+        # Only apply submodule pre-allocation for CPU specs; device buffers
+        # do not share memory space with CPU submodule arenas.
+        # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+        graph_module.input_mem_buffer_sizes = (
+            submodule_bufsizes if device_type == DeviceType.CPU else []
+        )
+
+        # Run algorithm independently on this device's specs
+        device_bufsizes = algo(
+            alignment, device_specs, graph_module, graph_signature, extra_padding
+        )
+
+        # Calculate base mem_id in global space
+        base_mem_id = len(global_bufsizes)
+
+        # Append buffer sizes (skip index 0 which is constants placeholder)
+        global_bufsizes.extend(device_bufsizes[1:])
+
+        # Track device type for each new buffer slot
+        for _ in device_bufsizes[1:]:
+            buffer_device_types.append(device_type)
+
+        # Remap spec mem_ids from algo-local to global.
+        # The algorithm assigns mem_id starting from 1; remap to global position.
+        for spec in device_specs:
+            if spec.mem_id is not None:
+                spec.mem_id = (spec.mem_id - 1) + base_mem_id
+
+    # Ensure backward compatibility: at least [0, 0] when no specs exist
+    if len(global_bufsizes) < 2:
+        global_bufsizes.append(0)
+        buffer_device_types.append(DeviceType.CPU)
+
+    # 3. Insert free calls and build device buffer mapping
+    insert_calls_to_free(graph_module, all_specs)
+
+    has_device_buffers = any(dt != DeviceType.CPU for dt in buffer_device_types)
+    non_const_buffer_device: Optional[list[NonConstBufferDevice]] = None
+    if has_device_buffers:
+        non_const_buffer_device = [
+            NonConstBufferDevice(buffer_idx=i, device_type=dt, device_index=0)
+            for i, dt in enumerate(buffer_device_types)
+        ]
+
+    graph_module.meta["non_const_buffer_sizes"] = global_bufsizes
+    if non_const_buffer_device is not None:
+        graph_module.meta["non_const_buffer_device"] = non_const_buffer_device
+    return global_bufsizes
```
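
To make the merge-and-remap arithmetic in step 2 concrete, here is a toy walk-through with made-up buffer sizes (none of these numbers come from the commit). Each per-device run of the algorithm returns a `bufsizes` list whose index 0 is the constants placeholder and assigns local `mem_id`s starting at 1.

```python
# Toy numbers only: simulate how the loop above merges two per-device
# planning results into one global bufsizes list.
cpu_bufsizes = [0, 4096]   # hypothetical result for the CPU partition
cuda_bufsizes = [0, 2048]  # hypothetical result for the CUDA partition

global_bufsizes = [0]  # index 0 reserved for constants

# CPU merges first: base_mem_id = len(global_bufsizes) = 1, so a CPU spec
# with local mem_id=1 keeps global mem_id (1 - 1) + 1 = 1.
base_mem_id = len(global_bufsizes)
global_bufsizes.extend(cpu_bufsizes[1:])  # -> [0, 4096]

# CUDA merges next: base_mem_id is now 2, so a CUDA spec with local
# mem_id=1 is remapped to global mem_id (1 - 1) + 2 = 2.
base_mem_id = len(global_bufsizes)
global_bufsizes.extend(cuda_bufsizes[1:])  # -> [0, 4096, 2048]

print(global_bufsizes)  # [0, 4096, 2048]: constants, CPU buffer, CUDA buffer
```

This is why CPU specs end up with `mem_id=1` and CUDA specs with `mem_id=2` in the tests below, and why CPU and device tensors can never alias the same buffer.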

exir/passes/memory_planning_pass.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -153,6 +153,7 @@ def __init__(
         alloc_mutable_buffers: bool = True,
         share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
+        enable_non_cpu_memory_planning: bool = False,
     ) -> None:
         r"""
         alloc_graph_input/alloc_graph_output will have 4 different combinations
@@ -173,6 +174,7 @@ def __init__(
         self.alloc_mutable_buffers = alloc_mutable_buffers
         self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning
         self.state = _MemoryPlanningState()

     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
@@ -250,6 +252,7 @@ def run(
             # If mutable buffers are shared, then do not allocate them in the
             # main memory planning algo; they are allocated in run_multimethod.
             self.alloc_mutable_buffers and not self.share_mutable_buffers,
+            self.enable_non_cpu_memory_planning,
         )

         if self.share_mutable_buffers and graph_signature is not None:
```
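
For completeness, a sketch of constructing the pass directly with the new keyword, assuming `MemoryPlanningPass` is importable from `executorch.exir.passes` and leaving the other parameters at the defaults shown in the diff:

```python
# Sketch: direct construction of the pass with the flag added here.
from executorch.exir.passes import MemoryPlanningPass

mpp = MemoryPlanningPass(
    alloc_graph_input=True,
    alloc_graph_output=True,
    enable_non_cpu_memory_planning=True,  # new keyword from this commit
)
```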

exir/program/_program.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -1792,6 +1792,12 @@ def to_executorch(  # noqa (FLAKE8) C901
         else:
             memory_planning_pass = config.memory_planning_pass
         # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
+        # Propagate enable_non_cpu_memory_planning from the top-level config
+        # to the pass instance so that device-aware partitioning is applied.
+        if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"):
+            memory_planning_pass.enable_non_cpu_memory_planning = (
+                config.enable_non_cpu_memory_planning
+            )
         if hasattr(memory_planning_pass, "run"):
             new_gm_res = memory_planning_pass.run(new_gm, new_signature)
         else:
```
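
The `hasattr` guard keeps the propagation duck-typed: only a pass object that already exposes the attribute receives the config value, while any other user-supplied pass is left untouched. A minimal illustration, where `LegacyPlanningPass` is a hypothetical pass written before this flag existed:

```python
# `LegacyPlanningPass` is hypothetical; it models a user-provided pass
# that predates enable_non_cpu_memory_planning.
class LegacyPlanningPass:
    def run(self, graph_module, graph_signature):
        ...  # custom planning logic


pass_instance = LegacyPlanningPass()
if hasattr(pass_instance, "enable_non_cpu_memory_planning"):
    # Not taken for LegacyPlanningPass: the flag is silently ignored
    # rather than being set on a pass that would never read it.
    pass_instance.enable_non_cpu_memory_planning = True
```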

exir/tests/test_memory_planning.py

Lines changed: 169 additions & 0 deletions
```diff
@@ -29,6 +29,8 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     _do_user_inputs_exist,
+    apply_algo,
+    collect_specs_from_nodes,
     filter_nodes,
     get_node_tensor_specs,
     greedy,
@@ -45,6 +47,7 @@
     ToOutVarPass,
 )
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
+from executorch.exir.schema import DeviceType
 from executorch.exir.tensor import TensorSpec
 from functorch.experimental.control_flow import map as torch_map
 from parameterized import parameterized
@@ -1259,3 +1262,169 @@ def reset(self, k_zeros: torch.Tensor, v_zeros: torch.Tensor) -> None:
         self.assertEqual(v_cache[0].val.allocation_info.memory_id, 2)
         self.assertEqual(v_cache[0].val.allocation_info.memory_offset_low, 256)
         self.assertEqual(v_cache[0].val.allocation_info.memory_offset_high, 0)
+
+
+class TestDeviceAwareMemoryPlanning(unittest.TestCase):
+    """Tests for per-device memory planning (separate buffers per device type)."""
+
+    def _prepare_model(
+        self,
+    ) -> Tuple[GraphModule, ExportGraphSignature]:
+        """Prepare ToyModelForMemPlanning through SpecPropPass + ToOutVarPass."""
+        model = ToyModelForMemPlanning()
+        inputs = model.get_random_inputs()
+        edge = to_edge(export(model, inputs, strict=True))
+        gm = edge.exported_program().graph_module
+        gs = edge.exported_program().graph_signature
+        gm = PassManager(passes=[SpecPropPass(), ToOutVarPass()])(gm).graph_module
+        return gm, gs
+
+    def _get_planned_specs(
+        self,
+        gm: GraphModule,
+        gs: ExportGraphSignature,
+    ) -> list[TensorSpec]:
+        """Get the unique set of specs that apply_algo would plan."""
+        return list(
+            collect_specs_from_nodes(
+                gm.graph.nodes,
+                gs,
+                do_assertion=False,
+                ignore_graph_input=False,
+                ignore_graph_output=False,
+                ignore_mutable_buffers=False,
+            )
+        )
+
+    def test_cpu_only_unchanged(self) -> None:
+        """CPU-only specs produce bufsizes = [0, X] with no device metadata."""
+        gm, gs = self._prepare_model()
+
+        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
+        bufsizes = apply_algo(
+            algo, gm, 16, gs, enable_non_cpu_memory_planning=True
+        )
+
+        # All specs are CPU-resident, so planning yields [constants, cpu] and
+        # attaches no device metadata.
+        self.assertEqual(bufsizes[0], 0)  # constants
+        self.assertGreater(bufsizes[1], 0)  # CPU activations
+        self.assertNotIn("non_const_buffer_device", gm.meta)
+
+    def test_all_cuda_no_wasted_slots(self) -> None:
+        """CUDA-only specs produce [0, X] with CUDA at buffer index 1."""
+        gm, gs = self._prepare_model()
+        specs = self._get_planned_specs(gm, gs)
+        for spec in specs:
+            spec.device = DeviceType.CUDA
+
+        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
+        bufsizes = apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True)
+
+        # [0, cuda_size] -- no wasted CPU buffer slot
+        self.assertEqual(len(bufsizes), 2)
+        self.assertEqual(bufsizes[0], 0)
+        self.assertGreater(bufsizes[1], 0)
+        # Device mapping should be present
+        self.assertIn("non_const_buffer_device", gm.meta)
+        device_map = gm.meta["non_const_buffer_device"]
+        self.assertEqual(len(device_map), 2)
+        self.assertEqual(device_map[0].device_type, DeviceType.CPU)  # constants
+        self.assertEqual(device_map[1].device_type, DeviceType.CUDA)
+
+    def test_mixed_cpu_cuda_separate_buffers(self) -> None:
+        """CPU specs at mem_id=1, CUDA specs at mem_id=2, separate sizes."""
+        gm, gs = self._prepare_model()
+        specs = self._get_planned_specs(gm, gs)
+
+        # Set second half of specs to CUDA
+        mid = len(specs) // 2
+        self.assertGreater(mid, 0)
+        cpu_specs = specs[:mid]
+        cuda_specs = specs[mid:]
+        for spec in cuda_specs:
+            spec.device = DeviceType.CUDA
+
+        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
+        bufsizes = apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True)
+
+        # [constants, cpu_activations, cuda_activations]
+        self.assertEqual(len(bufsizes), 3)
+        self.assertEqual(bufsizes[0], 0)
+        self.assertGreater(bufsizes[1], 0)
+        self.assertGreater(bufsizes[2], 0)
+
+        # CPU specs should have mem_id=1, CUDA specs should have mem_id=2
+        for spec in cpu_specs:
+            self.assertEqual(spec.mem_id, 1, f"CPU spec has wrong mem_id: {spec.mem_id}")
+        for spec in cuda_specs:
+            self.assertEqual(spec.mem_id, 2, f"CUDA spec has wrong mem_id: {spec.mem_id}")
+
+    def test_mem_offset_correct_after_remap(self) -> None:
+        """After remapping, mem_offset is relative to its own buffer."""
+        gm, gs = self._prepare_model()
+        specs = self._get_planned_specs(gm, gs)
+
+        # Set the last spec to CUDA (sole CUDA tensor)
+        cuda_spec = specs[-1]
+        cuda_spec.device = DeviceType.CUDA
+
+        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
+        bufsizes = apply_algo(
+            algo, gm, 16, gs, enable_non_cpu_memory_planning=True
+        )
+
+        # The CUDA spec is the only tensor in its buffer, so offset should be 0
+        self.assertEqual(cuda_spec.mem_offset, 0)
+        # The CUDA buffer should fit exactly this tensor
+        cuda_mem_id = cuda_spec.mem_id
+        self.assertIsNotNone(cuda_mem_id)
+        assert cuda_mem_id is not None
+        self.assertGreaterEqual(bufsizes[cuda_mem_id], cuda_spec.allocated_memory)
+
+    def test_no_cross_device_memory_sharing(self) -> None:
+        """Specs on different devices never share buffers, regardless of lifetime."""
+        gm, gs = self._prepare_model()
+        specs = self._get_planned_specs(gm, gs)
+        self.assertGreaterEqual(len(specs), 2)
+
+        # Assign alternating specs to CUDA to ensure some pairs have
+        # non-overlapping lifetimes (which greedy would normally share).
+        for i, spec in enumerate(specs):
+            if i % 2 == 0:
+                spec.device = DeviceType.CUDA
+
+        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
+        apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True)
+
+        # Verify CPU and CUDA specs have disjoint mem_ids
+        cpu_mem_ids: set[int] = set()
+        cuda_mem_ids: set[int] = set()
+        for i, spec in enumerate(specs):
+            if spec.mem_id is not None:
+                if i % 2 == 0:
+                    cuda_mem_ids.add(spec.mem_id)
+                else:
+                    cpu_mem_ids.add(spec.mem_id)
+
+        self.assertTrue(
+            cpu_mem_ids.isdisjoint(cuda_mem_ids),
+            f"CPU {cpu_mem_ids} and CUDA {cuda_mem_ids} should not share buffers",
+        )
+
+    def test_disabled_falls_back_to_cpu(self) -> None:
+        """With enable_non_cpu_memory_planning=False (default), CUDA specs are
+        planned into CPU memory -- no device-specific buffers are created."""
+        gm, gs = self._prepare_model()
+        specs = self._get_planned_specs(gm, gs)
+        for spec in specs:
+            spec.device = DeviceType.CUDA
+
+        algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy])
+        # Default: enable_non_cpu_memory_planning=False
+        bufsizes = apply_algo(algo, gm, 16, gs)
+
+        # All specs planned into a single CPU pool -- same as CPU-only
+        self.assertEqual(len(bufsizes), 2)
+        self.assertEqual(bufsizes[0], 0)
+        self.assertGreater(bufsizes[1], 0)
+        self.assertNotIn("non_const_buffer_device", gm.meta)
```
