Skip to content

Commit 012af21

Browse files
committed
[ET Device Support] Annotate device attributes of CUDA backend IO tensors as CUDA device
Pull Request resolved: #18080 Update the CUDA backend partitioner to annotate its IO tensors as CUDA device, and add checks in the CUDA backend to guarantee this works. ghstack-source-id: 354699535 @exported-using-ghexport Differential Revision: [D96010436](https://our.internmc.facebook.com/intern/diff/D96010436/)
1 parent 96621aa commit 012af21

3 files changed

Lines changed: 118 additions & 1 deletion

File tree

backends/cuda/cuda_partitioner.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip
1111
from executorch.exir._warnings import experimental
1212
from executorch.exir.backend.compile_spec_schema import CompileSpec
13+
from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY
1314

1415

1516
@final
@@ -19,7 +20,34 @@
1920
class CudaPartitioner(AotiPartitioner):
2021
"""
2122
CUDA partitioner driven by AOTInductor backend.
23+
24+
This partitioner adds a target_device compile spec to enable device info
25+
propagation. The PropagateDevicePass will read this spec and mark delegate
26+
output tensors with CUDA device type, which flows through to serialization.
2227
"""
2328

24-
def __init__(self, compile_spec: List[CompileSpec]) -> None:
29+
def __init__(
    self,
    compile_spec: List[CompileSpec],
    device_index: int = 0,
) -> None:
    """
    Initialize the CUDA partitioner.

    Ensures a target_device compile spec (e.g. "cuda:0") is present so
    that device info can be propagated to delegate IO tensors downstream.

    Args:
        compile_spec: List of compile specs for the backend.
        device_index: The CUDA device index (default: 0), used to build
            the target_device compile spec value.
    """
    # Only inject the target_device spec when the caller has not
    # already supplied one; never override an explicit choice.
    if all(spec.key != TARGET_DEVICE_COMPILE_SPEC_KEY for spec in compile_spec):
        target_spec = CompileSpec(
            TARGET_DEVICE_COMPILE_SPEC_KEY,
            f"cuda:{device_index}".encode("utf-8"),
        )
        # Build a fresh list so the caller's list is left untouched.
        compile_spec = [*compile_spec, target_spec]
    super().__init__(CudaBackend.__name__, compile_spec)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,26 @@ class ET_EXPERIMENTAL CudaBackend final
403403
n_outputs,
404404
args.size())
405405

406+
// Verify device info on all memory-planned, ET-driven IO tensors.
407+
// All input and output tensors should have device_type = CUDA, which
408+
// is set during serialization by PropagateDevicePass based on the
409+
// target_device compile spec from CudaPartitioner.
410+
//
411+
// Note: At this stage, the tensor memory is still on CPU. The device_type
412+
// is metadata indicating where the tensor *should* reside. The backend
413+
// is responsible for copying data to the actual CUDA device.
414+
for (size_t i = 0; i < n_inputs + n_outputs; i++) {
415+
auto* tensor = &(args[i]->toTensor());
416+
auto device_type = tensor->unsafeGetTensorImpl()->device_type();
417+
ET_CHECK_OR_RETURN_ERROR(
418+
device_type == executorch::runtime::etensor::DeviceType::CUDA,
419+
InvalidArgument,
420+
"Tensor %zu expected device_type=CUDA (1), got %d. "
421+
"Device info may not be properly propagated from CudaPartitioner.",
422+
i,
423+
static_cast<int>(device_type));
424+
}
425+
406426
// NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
407427
// optimization. We need to create GPU copies for CUDA kernel execution
408428
// using SlimTensor.

backends/cuda/tests/test_cuda_export.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,72 @@ def test_triton_kernel_mode_off(self):
325325
edge_program_manager,
326326
"SDPA kernel export with triton_kernel_mode=OFF failed",
327327
)
328+
329+
def test_device_info_propagated_to_cuda_delegate_outputs(self):
    """
    Verify that device info is propagated from export to serialization
    for CUDA delegate IO tensors.

    Propagation chain exercised here:
      1. CudaPartitioner adds a target_device="cuda:0" CompileSpec.
      2. PropagateDevicePass sets TensorSpec.device = CUDA for delegate outputs.
      3. The emitter serializes the device into ExtraTensorInfo.device_type.
      4. Serialized tensors therefore report DeviceType.CUDA.

    Note: at this stage the tensor memory is still on CPU. The CUDA backend
    copies data to the GPU at runtime; device info tagging is the first step
    toward full device-aware memory allocation.
    """
    from executorch.exir import schema

    class AddModule(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

    module = AddModule()
    module.eval()
    inputs = (torch.randn(2, 3), torch.randn(2, 3))

    # Export to CUDA with the full lowering pipeline.
    edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
    self.assertIsNotNone(edge_program_manager, "CUDA export failed")

    # Convert to ExecuTorch and access the serialized program.
    et_prog = edge_program_manager.to_executorch()
    program = et_prog._emitter_output.program

    # The execution plan must contain at least one delegate.
    plan = program.execution_plan[0]
    self.assertGreater(
        len(plan.delegates),
        0,
        "Expected at least one delegate in the execution plan",
    )

    # Partition the serialized tensors by annotated device type.
    cpu_tensors = []
    cuda_tensors = []

    for value in plan.values:
        if isinstance(value.val, schema.Tensor):
            tensor = value.val
            if (
                tensor.extra_tensor_info is not None
                and tensor.extra_tensor_info.device_type == schema.DeviceType.CUDA
            ):
                cuda_tensors.append(tensor)
            else:
                # Either no extra_tensor_info, or device_type is CPU (default).
                cpu_tensors.append(tensor)

    # Both input and output tensors should be on the CUDA device for now.
    # (Message fixed: it must describe the expectation on failure, not
    # assert the success condition as a statement of fact.)
    self.assertEqual(
        len(cpu_tensors),
        0,
        "Expected no CPU-annotated tensors; all delegate IO tensors "
        "should be annotated with the CUDA device type",
    )
    self.assertGreater(
        len(cuda_tensors),
        3,
        "Expected CUDA tensors for delegate outputs",
    )

0 commit comments

Comments
 (0)