Skip to content

Commit 7ff22e2

Browse files
committed
[ET Device Support] Define AOT device copy ops registry
Define et_copy._h2d_copy and et_copy._d2h_copy custom ops for explicit host-to-device and device-to-host data transfer at delegate boundaries. Follows dim_order_ops_registry.py pattern: - Defines functional and out variants for both ops - Tracing implementations: clone/copy (CPU-only during tracing) - Registered via torch.library Differential Revision: [D99636779](https://our.internmc.facebook.com/intern/diff/D99636779/) [ghstack-poisoned]
1 parent 6c5df5d commit 7ff22e2

4 files changed

Lines changed: 150 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,14 @@ fbcode_target(_kind = runtime.python_library,
369369
],
370370
)
371371

372+
fbcode_target(_kind = runtime.python_library,
373+
name = "device_copy_ops_registry",
374+
srcs = ["_device_copy_ops_registry.py"],
375+
deps = [
376+
"//caffe2:torch",
377+
],
378+
)
379+
372380
fbcode_target(_kind = runtime.python_library,
373381
name = "memory_format_ops_pass",
374382
srcs = [
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Registry for device copy ops used to insert explicit H2D (host-to-device)
9+
and D2H (device-to-host) data transfer operations at delegate boundaries.
10+
11+
These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning
12+
is True, making the graph functional by explicitly transferring data between
13+
CPU and device memory.
14+
15+
Follows the same registration pattern as dim_order_ops_registry.py.
16+
"""
17+
18+
import torch
19+
from torch.library import impl, Library
20+
21+
lib = Library("et_copy", "DEF")
22+
23+
# _h2d_copy: copies a CPU tensor to device memory.
24+
# At tracing time, this is a clone (both on CPU). At runtime, the out tensor
25+
# is memory-planned on device, and the kernel calls
26+
# DeviceAllocator::copy_host_to_device.
27+
lib.define("_h2d_copy(Tensor self) -> Tensor")
28+
lib.define("_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
29+
30+
# _d2h_copy: copies a device tensor to CPU memory.
31+
# At tracing time, this is a clone (both on CPU). At runtime, the self tensor
32+
# has device memory, and the kernel calls DeviceAllocator::copy_device_to_host.
33+
lib.define("_d2h_copy(Tensor self) -> Tensor")
34+
lib.define("_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
35+
36+
37+
@impl(lib, "_h2d_copy", "CompositeImplicitAutograd")
38+
def _h2d_copy_impl(self: torch.Tensor) -> torch.Tensor:
39+
# During tracing, both tensors are on CPU. Just clone to represent the transfer.
40+
return self.clone()
41+
42+
43+
@impl(lib, "_h2d_copy.out", "CompositeImplicitAutograd")
44+
def _h2d_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
45+
out.copy_(self)
46+
return out
47+
48+
49+
@impl(lib, "_d2h_copy", "CompositeImplicitAutograd")
50+
def _d2h_copy_impl(self: torch.Tensor) -> torch.Tensor:
51+
# During tracing, both tensors are on CPU. Just clone to represent the transfer.
52+
return self.clone()
53+
54+
55+
@impl(lib, "_d2h_copy.out", "CompositeImplicitAutograd")
56+
def _d2h_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
57+
out.copy_(self)
58+
return out

exir/tests/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,14 @@ python_unittest(
504504
"//executorch/exir/passes:propagate_device_pass",
505505
],
506506
)
507+
508+
python_unittest(
509+
name = "device_copy_ops",
510+
srcs = [
511+
"test_device_copy_ops.py",
512+
],
513+
deps = [
514+
"//caffe2:torch",
515+
"//executorch/exir/passes:device_copy_ops_registry",
516+
],
517+
)

exir/tests/test_device_copy_ops.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
# Import the registry to register the ops
10+
import executorch.exir.passes._device_copy_ops_registry # noqa: F401
11+
12+
import torch
13+
14+
15+
class DeviceCopyOpsRegistryTest(unittest.TestCase):
16+
"""Tests that et_copy._h2d_copy and et_copy._d2h_copy ops are correctly
17+
registered and produce expected outputs during tracing (CPU-only)."""
18+
19+
def test_h2d_copy_functional(self):
20+
"""_h2d_copy should return a clone of the input tensor."""
21+
x = torch.randn(2, 3)
22+
result = torch.ops.et_copy._h2d_copy(x)
23+
self.assertEqual(result.shape, x.shape)
24+
self.assertEqual(result.dtype, x.dtype)
25+
self.assertTrue(torch.equal(result, x))
26+
# Should be a new tensor, not the same object
27+
self.assertFalse(result.data_ptr() == x.data_ptr())
28+
29+
def test_d2h_copy_functional(self):
30+
"""_d2h_copy should return a clone of the input tensor."""
31+
x = torch.randn(4, 5)
32+
result = torch.ops.et_copy._d2h_copy(x)
33+
self.assertEqual(result.shape, x.shape)
34+
self.assertEqual(result.dtype, x.dtype)
35+
self.assertTrue(torch.equal(result, x))
36+
self.assertFalse(result.data_ptr() == x.data_ptr())
37+
38+
def test_h2d_copy_out_variant(self):
39+
"""_h2d_copy.out should copy data into the provided out tensor."""
40+
x = torch.randn(3, 3)
41+
out = torch.empty(3, 3)
42+
result = torch.ops.et_copy._h2d_copy.out(x, out=out)
43+
self.assertTrue(result is out)
44+
self.assertTrue(torch.equal(out, x))
45+
46+
def test_d2h_copy_out_variant(self):
47+
"""_d2h_copy.out should copy data into the provided out tensor."""
48+
x = torch.randn(2, 4)
49+
out = torch.empty(2, 4)
50+
result = torch.ops.et_copy._d2h_copy.out(x, out=out)
51+
self.assertTrue(result is out)
52+
self.assertTrue(torch.equal(out, x))
53+
54+
def test_h2d_copy_preserves_dtype(self):
55+
"""_h2d_copy should work with various dtypes."""
56+
for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
57+
x = torch.ones(2, 2, dtype=dtype)
58+
result = torch.ops.et_copy._h2d_copy(x)
59+
self.assertEqual(result.dtype, dtype)
60+
self.assertTrue(torch.equal(result, x))
61+
62+
def test_h2d_copy_scalar_tensor(self):
63+
"""_h2d_copy should handle 0-dim tensors."""
64+
x = torch.tensor(3.14)
65+
result = torch.ops.et_copy._h2d_copy(x)
66+
self.assertEqual(result.shape, torch.Size([]))
67+
self.assertTrue(torch.equal(result, x))
68+
69+
def test_d2h_copy_empty_tensor(self):
70+
"""_d2h_copy should handle empty tensors."""
71+
x = torch.empty(0, 3)
72+
result = torch.ops.et_copy._d2h_copy(x)
73+
self.assertEqual(result.shape, torch.Size([0, 3]))

0 commit comments

Comments
 (0)