Skip to content

Commit 7d8063f

Browse files
authored
[ET Device Support] Define AOT device copy ops registry (pytorch#19748)
clone pytorch#18728 due to bot crash
1 parent c27cc5d commit 7d8063f

4 files changed

Lines changed: 150 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,14 @@ fbcode_target(_kind = runtime.python_library,
381381
],
382382
)
383383

384+
# Python library target for the device copy ops registry module
# (_device_copy_ops_registry.py); torch is its only declared dependency.
fbcode_target(_kind = runtime.python_library,
385+
name = "device_copy_ops_registry",
386+
srcs = ["_device_copy_ops_registry.py"],
387+
deps = [
388+
"//caffe2:torch",
389+
],
390+
)
391+
384392
fbcode_target(_kind = runtime.python_library,
385393
name = "memory_format_ops_pass",
386394
srcs = [
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Registry for device copy ops used to insert explicit H2D (host-to-device)
9+
and D2H (device-to-host) data transfer operations at delegate boundaries.
10+
11+
These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning
12+
is True, making the graph functional by explicitly transferring data between
13+
CPU and device memory.
14+
15+
Follows the same registration pattern as dim_order_ops_registry.py.
16+
"""
17+
18+
import torch
19+
from torch.library import impl, Library
20+
21+
# Library handle for the "et_copy" op namespace. "DEF" means this module owns
# the namespace; the schemas are declared right below and the Python kernels
# are attached to this same handle with @impl later in the module.
lib = Library("et_copy", "DEF")
22+
23+
# _h2d_copy: copies a CPU tensor to device memory.
24+
# At tracing time, this is a clone (both on CPU). At runtime, the out tensor
25+
# is memory-planned on device, and the kernel calls
26+
# DeviceAllocator::copy_host_to_device.
27+
lib.define("_h2d_copy(Tensor self) -> Tensor")
28+
# Out overload: `Tensor(a!)` marks `out` as written in place and returned.
lib.define("_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
29+
30+
# _d2h_copy: copies a device tensor to CPU memory.
31+
# At tracing time, this is a clone (both on CPU). At runtime, the self tensor
32+
# has device memory, and the kernel calls DeviceAllocator::copy_device_to_host.
33+
lib.define("_d2h_copy(Tensor self) -> Tensor")
34+
# Out overload: `Tensor(a!)` marks `out` as written in place and returned.
lib.define("_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
35+
36+
37+
@impl(lib, "_h2d_copy", "CompositeImplicitAutograd")
def _h2d_copy_impl(self: torch.Tensor) -> torch.Tensor:
    """Trace-time kernel for the functional host-to-device copy.

    While tracing, the source and the result both live on CPU, so the
    transfer is represented by a plain clone of ``self``.

    Args:
        self: Tensor to copy.

    Returns:
        A new tensor holding a copy of ``self``.
    """
    transferred = self.clone()
    return transferred
41+
42+
43+
@impl(lib, "_h2d_copy.out", "CompositeImplicitAutograd")
def _h2d_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """Out-variant kernel for the host-to-device copy.

    Args:
        self: Tensor to copy from.
        out: Destination tensor; overwritten in place with the data of ``self``.

    Returns:
        The ``out`` tensor that was passed in.
    """
    # Tensor.copy_ writes into `out` and hands the destination tensor back.
    return out.copy_(self)
47+
48+
49+
@impl(lib, "_d2h_copy", "CompositeImplicitAutograd")
def _d2h_copy_impl(self: torch.Tensor) -> torch.Tensor:
    """Trace-time kernel for the functional device-to-host copy.

    While tracing, the source and the result both live on CPU, so the
    transfer is represented by a plain clone of ``self``.

    Args:
        self: Tensor to copy.

    Returns:
        A new tensor holding a copy of ``self``.
    """
    host_side = self.clone()
    return host_side
53+
54+
55+
@impl(lib, "_d2h_copy.out", "CompositeImplicitAutograd")
def _d2h_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """Out-variant kernel for the device-to-host copy.

    Args:
        self: Tensor to copy from.
        out: Destination tensor; overwritten in place with the data of ``self``.

    Returns:
        The ``out`` tensor that was passed in.
    """
    # Tensor.copy_ writes into `out` and hands the destination tensor back.
    return out.copy_(self)

exir/tests/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,14 @@ python_unittest(
504504
"//executorch/exir/passes:propagate_device_pass",
505505
],
506506
)
507+
508+
# Unit tests for the et_copy device copy ops (test_device_copy_ops.py).
# Depends on the registry library, which the test imports to register the ops.
python_unittest(
509+
name = "device_copy_ops",
510+
srcs = [
511+
"test_device_copy_ops.py",
512+
],
513+
deps = [
514+
"//caffe2:torch",
515+
"//executorch/exir/passes:device_copy_ops_registry",
516+
],
517+
)

exir/tests/test_device_copy_ops.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
# Import the registry to register the ops
10+
import executorch.exir.passes._device_copy_ops_registry # noqa: F401
11+
12+
import torch
13+
14+
15+
class DeviceCopyOpsRegistryTest(unittest.TestCase):
    """Tests that et_copy._h2d_copy and et_copy._d2h_copy ops are correctly
    registered and produce expected outputs during tracing (CPU-only).

    The ops are registered as a side effect of importing the registry module
    at the top of this file; every tensor used here lives on CPU.
    """

    def test_h2d_copy_functional(self):
        """_h2d_copy should return a clone of the input tensor."""
        x = torch.randn(2, 3)
        result = torch.ops.et_copy._h2d_copy(x)
        self.assertEqual(result.shape, x.shape)
        self.assertEqual(result.dtype, x.dtype)
        self.assertTrue(torch.equal(result, x))
        # A clone owns its own storage; assertNotEqual reports both pointers
        # on failure instead of a bare "True is not false".
        self.assertNotEqual(result.data_ptr(), x.data_ptr())

    def test_d2h_copy_functional(self):
        """_d2h_copy should return a clone of the input tensor."""
        x = torch.randn(4, 5)
        result = torch.ops.et_copy._d2h_copy(x)
        self.assertEqual(result.shape, x.shape)
        self.assertEqual(result.dtype, x.dtype)
        self.assertTrue(torch.equal(result, x))
        self.assertNotEqual(result.data_ptr(), x.data_ptr())

    def test_h2d_copy_out_variant(self):
        """_h2d_copy.out should copy data into the provided out tensor."""
        x = torch.randn(3, 3)
        out = torch.empty(3, 3)
        result = torch.ops.et_copy._h2d_copy.out(x, out=out)
        # The out overload must hand back the very tensor it was given.
        self.assertIs(result, out)
        self.assertTrue(torch.equal(out, x))

    def test_d2h_copy_out_variant(self):
        """_d2h_copy.out should copy data into the provided out tensor."""
        x = torch.randn(2, 4)
        out = torch.empty(2, 4)
        result = torch.ops.et_copy._d2h_copy.out(x, out=out)
        self.assertIs(result, out)
        self.assertTrue(torch.equal(out, x))

    def test_h2d_copy_preserves_dtype(self):
        """_h2d_copy should work with various dtypes."""
        for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
            # subTest names the failing dtype and lets the others still run.
            with self.subTest(dtype=dtype):
                x = torch.ones(2, 2, dtype=dtype)
                result = torch.ops.et_copy._h2d_copy(x)
                self.assertEqual(result.dtype, dtype)
                self.assertTrue(torch.equal(result, x))

    def test_h2d_copy_scalar_tensor(self):
        """_h2d_copy should handle 0-dim tensors."""
        x = torch.tensor(3.14)
        result = torch.ops.et_copy._h2d_copy(x)
        self.assertEqual(result.shape, torch.Size([]))
        self.assertTrue(torch.equal(result, x))

    def test_d2h_copy_empty_tensor(self):
        """_d2h_copy should handle empty tensors."""
        x = torch.empty(0, 3)
        result = torch.ops.et_copy._d2h_copy(x)
        self.assertEqual(result.shape, torch.Size([0, 3]))
        # With zero elements there is no data to compare, so also pin the dtype.
        self.assertEqual(result.dtype, x.dtype)

0 commit comments

Comments
 (0)